# pylint: disable=too-many-arguments
# pylint: disable=undefined-variable
# pylint: disable=no-member
""" Batch class for storing CT-scans. """
import os
import logging
import dill as pickle
import numpy as np
import aiofiles
import blosc
import dicom
import SimpleITK as sitk
from ..dataset import Batch, action, inbatch_parallel, any_action_failed, DatasetIndex # pylint: disable=no-name-in-module
from .resize import resize_scipy, resize_pil
from .segment import calc_lung_mask_numba
from .mip import make_xip_numba
from .flip import flip_patient_numba
from .crop import make_central_crop
from .patches import get_patches_numba, assemble_patches, calc_padding_size
from .rotate import rotate_3D
from .dump import dump_data
# logger initialization
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
AIR_HU = -2000
DARK_HU = -2000
class CTImagesBatch(Batch): # pylint: disable=too-many-public-methods
""" Batch class for storing batch of CT-scans in 3D.
Contains a component `images` = 3d-array of stacked scans
along number_of_slices (z) axis (aka "skyscraper"), associated information
for subsetting individual patient's 3D scan (_bounds, origin, spacing) and
various methods to preprocess the data.
Parameters
----------
index : dataset.index
ids of scans to be put in a batch
Attributes
----------
components : tuple of strings.
List names of data components of a batch, which are `images`,
`origin` and `spacing`.
NOTE: Implementation of this attribute is required by Base class.
index : dataset.index
represents indices of scans from a batch
images : ndarray
contains ct-scans for all patients in batch.
spacing : ndarray of floats
represents distances between pixels in world coordinates
origin : ndarray of floats
contains world coordinates of (0, 0, 0)-pixel of scans
"""
components = "images", "spacing", "origin"
def __init__(self, index, *args, **kwargs):
    """ Construct the batch and initialize its basic attributes.

    Parameters
    ----------
    index : Dataset.Index class.
        Required indexing of objects (files).
    """
    super().__init__(index, *args, **kwargs)
    # start out with empty components; _init_data fills defaults below
    self.images = None
    self._bounds = None
    self.origin = None
    self.spacing = None
    n_items = len(self)
    self._init_data(spacing=np.ones(shape=(n_items, 3)),
                    origin=np.zeros(shape=(n_items, 3)),
                    bounds=np.array([], dtype='int'))
def _if_component_filled(self, component):
    """ Tell whether a component of the batch currently holds data.

    Parameters
    ----------
    component : str
        name of the component to be checked.

    Returns
    -------
    bool
        True if filled, False if not.
    """
    value = getattr(self, component, None)
    return value is not None
def _init_data(self, bounds=None, **kwargs):
    """ Initialize `_bounds` and the data components (images, origin, spacing).

    Serves as the initializer of inner batch structures; called from
    `__init__` and from methods that rebuild the batch.

    Parameters
    ----------
    bounds : ndarray(n_patients + 1, dtype=int) or None
        1d-array of bound-floors for each scan 3D array; its length is
        number of items in batch + 1. Stored into self._bounds.
    **kwargs
        components to set as attributes, typically:
        images : ndarray(n_patients * z, y, x) or None
            skyscraper of stacked scans; each patient must have the
            same constant `z, y, x` shape.
        origin : ndarray(n_patients, 3) or None
            world coordinates of each scan's (0, 0, 0)-pixel (z, y, x order).
        spacing : ndarray(n_patients, 3) or None
            inter-pixel distances along z, y, x in world coordinates.
    """
    # keep the previous bounds when no new ones are supplied
    if bounds is not None:
        self._bounds = bounds
    for name, data in kwargs.items():
        setattr(self, name, data)
@classmethod
def split(cls, batch, batch_size):
    """ Split one batch in two batches.

    The lens of the 2 batches are `batch_size` and `len(batch) - batch_size`.

    Parameters
    ----------
    batch : Batch class instance
        batch to be splitted in two
    batch_size : int
        length of the first returned batch.
        If batch_size >= len(batch), return None instead of a 2nd batch.

    Returns
    -------
    tuple of batches
        (1st_Batch, 2nd_Batch)

    Notes
    -----
    Does not change the structure of the input Batch.index; indices of the
    output batches are subsets of the input index.
    """
    if batch_size == 0:
        return (None, batch)
    if batch_size >= len(batch):
        return (batch, None)
    # form indices for both batches
    size_first = batch_size
    ix_first = batch.index.create_subset(batch.indices[:size_first])
    ix_second = batch.index.create_subset(batch.indices[size_first:])
    # init batches
    batches = cls(ix_first), cls(ix_second)
    # put non-None components in batch-parts
    for batch_part in batches:
        for component in batch.components:
            if getattr(batch, component) is not None:
                comps = []
                for ix in batch_part.indices:
                    # get the item's component slice and collect it
                    comp_pos = batch.get_pos(None, component, ix)
                    comps.append(getattr(batch, component)[comp_pos])
                # set the component for the whole batch-part
                setattr(batch_part, component, np.concatenate(comps))
            else:
                setattr(batch_part, component, None)
    # set _bounds attrs if filled in batch
    if len(batch._bounds) >= 2:  # pylint: disable=protected-access
        for batch_part in batches:
            n_slices = []
            for ix in batch_part.indices:
                ix_pos_initial = batch.index.get_pos(ix)
                n_slices.append(batch.upper_bounds[ix_pos_initial]
                                - batch.lower_bounds[ix_pos_initial])
            # update _bounds in new batches
            # NOTE: np.int alias was removed in numpy >= 1.24, use builtin int
            batch_part._bounds = np.cumsum([0] + n_slices, dtype=int)  # pylint: disable=protected-access
    return batches
@classmethod
def concat(cls, batches):
    """ Concatenate several batches into one large batch.

    Assumes the same components are filled in all supplied batches.

    Parameters
    ----------
    batches : list or tuple of batches
        sequence of batches to be concatenated

    Returns
    -------
    batch
        large batch with length = sum of lengths of concatenated batches,
        or None if nothing remains after dropping empty entries.

    Notes
    -----
    Old batches' indexes are dropped: the new large batch gets a fresh
    np.arange index. None-entries and zero-length batches are ignored.
    """
    # leave only non-empty batches
    batches = [batch for batch in batches if batch is not None]
    batches = [batch for batch in batches if len(batch) > 0]
    if len(batches) == 0:
        return None
    # create index for the large batch and init batch
    ixbatch = DatasetIndex(np.arange(np.sum([len(batch) for batch in batches])))
    large_batch = cls(ixbatch)
    # set non-none components in the large batch
    for component in batches[0].components:
        comps = None
        if getattr(batches[0], component) is not None:
            comps = np.concatenate([getattr(batch, component) for batch in batches])
        setattr(large_batch, component, comps)
    # set _bounds-attr in large batch
    n_slices = np.zeros(shape=len(large_batch))
    ctr = 0
    for batch in batches:
        n_slices[ctr: ctr + len(batch)] = batch.upper_bounds - batch.lower_bounds
        ctr += len(batch)
    # NOTE: np.int alias was removed in numpy >= 1.24, use builtin int
    large_batch._bounds = np.cumsum(np.insert(n_slices, 0, 0), dtype=int)  # pylint: disable=protected-access
    return large_batch
@classmethod
def merge(cls, batches, batch_size=None):
    """ Concatenate a list of batches, then split the result in two batches
    of sizes (batch_size, sum(lens of batches) - batch_size).

    Parameters
    ----------
    batches : list of batches
    batch_size : int
        length of the first resulting batch

    Returns
    -------
    tuple of batches
        (new_batch, rest_batch)

    Notes
    -----
    Performs a split of one middle batch followed by two concats,
    for speed considerations.
    """
    # drop None-entries and empty batches
    batches = [b for b in batches if b is not None and len(b) > 0]
    if batch_size is None:
        return (cls.concat(batches), None)
    total_len = np.sum([len(b) for b in batches])
    if total_len <= batch_size:
        return (cls.concat(batches), None)
    # locate the batch that straddles the batch_size boundary
    cum_len = 0
    middle_pos = None
    for pos, candidate in enumerate(batches):
        cum_len += len(candidate)
        if cum_len >= batch_size:
            middle_pos = pos
            break
    middle = batches[middle_pos]
    # split the middle batch so the left part completes batch_size items
    left_middle, right_middle = cls.split(middle, len(middle) - cum_len + batch_size)
    # assemble merged and rest-batches
    merged = cls.concat(batches[:middle_pos] + [left_middle])
    rest = cls.concat([right_middle] + batches[middle_pos + 1:])
    return merged, rest
@action
def load(self, fmt='dicom', components=None, bounds=None, **kwargs):  # pylint: disable=arguments-differ
    """ Load 3d scans data in batch.

    Parameters
    ----------
    fmt : str
        type of data. Can be 'dicom'|'blosc'|'raw'|'ndarray'
    components : tuple, list, ndarray of strings or str
        names of batch component(s) to load. Only honored when
        fmt='blosc'; otherwise all available components are loaded.
        If None with fmt='blosc', all components are loaded as well.
    bounds : ndarray(n_patients + 1, dtype=np.int) or None
        needed iff fmt='ndarray'; bound-floors for items from the
        `skyscraper` (stacked scans).
    **kwargs
        images : ndarray(n_patients * z, y, x) or None
            needed only if fmt='ndarray'; the `skyscraper` array.
        origin : ndarray(n_patients, 3) or None
            needed only if fmt='ndarray'; scans' origins in world coords.
        spacing : ndarray(n_patients, 3) or None
            needed only if fmt='ndarray'; spacings along `z,y,x` axes.

    Returns
    -------
    self

    Examples
    --------
    DICOM:

    >>> index = FilesIndex(path="/some/path/*.dcm", no_ext=True)
    >>> batch = CTImagesBatch(index)
    >>> batch.load(fmt='dicom')

    Ndarray (images_array is a "skyscraper" of shape (400, 256, 256),
    bounds = np.asarray([0, 100, 400])):

    >>> batch.load(fmt='ndarray', images=images_array, bounds=bounds)
    """
    if fmt == 'ndarray':
        self._init_data(bounds=bounds, **kwargs)
    elif fmt == 'dicom':
        self._load_dicom()  # pylint: disable=no-value-for-parameter
    elif fmt == 'blosc':
        if components is None:
            components = self.components
        # make components iterable
        components = np.asarray(components).reshape(-1)
        self._load_blosc(components=components)  # pylint: disable=no-value-for-parameter
    elif fmt == 'raw':
        self._load_raw()  # pylint: disable=no-value-for-parameter
    else:
        raise TypeError("Incorrect type of batch source")
    return self
@inbatch_parallel(init='indices', post='_post_default', target='threads')
def _load_dicom(self, patient_id, **kwargs):
    """ Read one patient's dicom folder, load the 3d-array and convert
    it to Hounsfield Units (HU).

    Parameters
    ----------
    patient_id : str
        index of the patient; supplied by inbatch_parallel.

    Returns
    -------
    ndarray
        patient's 3d scan in HU, dtype int16.

    Notes
    -----
    Conversion to hounsfield unit scale using meta from dicom-scans is performed.
    """
    # put 2d-scans for each patient in a list
    patient_pos = self.index.get_pos(patient_id)
    patient_folder = self.index.get_fullpath(patient_id)
    list_of_dicoms = [dicom.read_file(os.path.join(patient_folder, s))
                      for s in os.listdir(patient_folder)]
    # order slices top-down along z
    list_of_dicoms.sort(key=lambda x: int(x.ImagePositionPatient[2]), reverse=True)
    dicom_slice = list_of_dicoms[0]
    intercept_pat = dicom_slice.RescaleIntercept
    slope_pat = dicom_slice.RescaleSlope
    # NOTE: np.float alias was removed in numpy >= 1.24, use builtin float
    self.spacing[patient_pos, ...] = np.asarray([float(dicom_slice.SliceThickness),
                                                 float(dicom_slice.PixelSpacing[0]),
                                                 float(dicom_slice.PixelSpacing[1])], dtype=float)
    self.origin[patient_pos, ...] = np.asarray([float(dicom_slice.ImagePositionPatient[2]),
                                                float(dicom_slice.ImagePositionPatient[0]),
                                                float(dicom_slice.ImagePositionPatient[1])], dtype=float)
    patient_data = np.stack([s.pixel_array for s in list_of_dicoms]).astype(np.int16)
    # zero out air-voxels before rescaling
    patient_data[patient_data == AIR_HU] = 0
    # perform conversion to HU
    if slope_pat != 1:
        patient_data = slope_pat * patient_data.astype(np.float64)
        patient_data = patient_data.astype(np.int16)
    patient_data += np.int16(intercept_pat)
    return patient_data
def _prealloc_skyscraper_components(self, components, fmt='blosc'):
    """ Read shapes of skyscraper-components dumped with blosc,
    allocate memory for them, update self._bounds.

    Used for a more memory-efficient load.

    Parameters
    ----------
    components : str or iterable
        iterable of components we need to preload.
    fmt : str
        format in which components are stored on disk.

    Raises
    ------
    NotImplementedError
        if fmt is not 'blosc'.
    OSError
        if a component's shape-file is missing on disk.
    """
    if fmt != 'blosc':
        raise NotImplementedError('Preload from {} not implemented yet'.format(fmt))
    # make iterable out of components-arg
    components = [components] if isinstance(components, str) else list(components)
    # load shapes, perform memory allocation
    for component in components:
        # NOTE: np.int alias was removed in numpy >= 1.24, use builtin int
        shapes = np.zeros((len(self), 3), dtype=int)
        for ix in self.indices:
            filename = os.path.join(self.index.get_fullpath(ix), component, 'data.shape')
            ix_pos = self._get_verified_pos(ix)
            # read shape and put it into shapes
            if not os.path.exists(filename):
                raise OSError("Component {} for item {} cannot be found on disk".format(component, ix))
            with open(filename, 'rb') as file:
                shapes[ix_pos, :] = pickle.load(file)
        # update bounds of items
        # TODO: once bounds for other components are added, make sure they are updated here in the right way
        self._bounds = np.cumsum(np.insert(shapes[:, 0], 0, 0), dtype=int)
        # preallocate the component
        skysc_shape = (self._bounds[-1], shapes[0, 1], shapes[0, 2])
        setattr(self, component, np.zeros(skysc_shape))
def _init_load_blosc(self, **kwargs):
    """ Init-function for load from blosc.

    Parameters
    ----------
    **kwargs
        components : iterable of components that need to be loaded

    Returns
    -------
    list
        list of ids of batch-items
    """
    # when images are to be (re)loaded, preallocate the skyscraper first
    if 'images' in kwargs['components']:
        self._prealloc_skyscraper_components('images')
    return self.indices
@inbatch_parallel(init='_init_load_blosc', post='_post_default', target='async', update=False)
async def _load_blosc(self, ix, *args, **kwargs):
    """ Read scans from blosc and put them into batch components.

    Parameters
    ----------
    ix : str
        index of the item to load; supplied by inbatch_parallel.
    **kwargs
        components : tuple
            names of components of data that should be loaded into self.

    Notes
    -----
    NO conversion to HU is done for blosc
    (because usually it's done before).
    """
    for source in kwargs['components']:
        # set correct extension for each component and choose a tool
        # for debyting and (possibly) decoding it
        if source in ['spacing', 'origin']:
            ext = 'pkl'
            unpacker = pickle.loads
        else:
            ext = 'blk'
            # bind `source` as a default argument: a plain closure would
            # capture the loop variable late and could read a later value
            def unpacker(byted, source=source):
                """ Debyte and decode an ndarray """
                debyted = blosc.unpack_array(byted)
                # read the decoder and apply it
                decod_path = os.path.join(self.index.get_fullpath(ix), source, 'data.decoder')
                # if file with decoder not exists, assume that no decoding is needed
                if os.path.exists(decod_path):
                    with open(decod_path, mode='rb') as file:
                        decoder = pickle.loads(file.read())
                else:
                    decoder = lambda x: x
                return decoder(debyted)
        comp_path = os.path.join(self.index.get_fullpath(ix), source, 'data' + '.' + ext)
        if not os.path.exists(comp_path):
            raise OSError("File with component {} doesn't exist".format(source))
        # read the component
        async with aiofiles.open(comp_path, mode='rb') as file:
            byted = await file.read()
        # de-byte it with the chosen tool
        component = unpacker(byted)
        # update needed slice(s) of component
        comp_pos = self.get_pos(None, source, ix)
        getattr(self, source)[comp_pos] = component
    return None
def _load_raw(self, **kwargs):  # pylint: disable=unused-argument
    """ Load scans from .raw images (with meta in .mhd).

    Returns
    -------
    self

    Notes
    -----
    Method does NO conversion to HU.
    NO multithreading is used, as SimpleITK (sitk) lib crashes
    in multithreading mode in experiments.
    """
    scans = []
    for patient_id in self.indices:
        raw_data = sitk.ReadImage(self.index.get_fullpath(patient_id))
        patient_pos = self.index.get_pos(patient_id)
        scans.append(sitk.GetArrayFromImage(raw_data))
        # *.mhd files carry origin and spacing with the axes order
        # inversed, so both arrays are simply reversed here
        self.origin[patient_pos, :] = np.array(raw_data.GetOrigin())[::-1]
        self.spacing[patient_pos, :] = np.array(raw_data.GetSpacing())[::-1]
    self.images = np.concatenate(scans, axis=0)
    self._bounds = np.cumsum(np.array([len(a) for a in [[]] + scans]))
    return self
@action
@inbatch_parallel(init='_init_dump', post='_post_default', target='async', update=False)
async def dump(self, ix, dst, components=None, fmt='blosc', index_to_name=None, i8_encoding_mode=None):
    """ Dump chosen ``components`` of scans' batch in folder ``dst`` in specified format.

    When some of the ``components`` are ``None``, a warning is printed and nothing is dumped.
    By default (``components is None``) ``dump`` attempts to dump all components.

    Parameters
    ----------
    dst : str
        destination-folder where all patients' data should be put
    components : tuple, list, ndarray of strings or str
        component(s) to dump (iterable or string). If not supplied,
        all components are dumped.
    fmt : 'blosc'
        format of dump. Currently only blosc-format is supported;
        a folder is created per patient (see example below).
    index_to_name : callable or None
        when supplied, should return str; maps each item's index to the name
        of the item's folder, i.e. the item is dumped into
        os.path.join(dst, index_to_name(items_index)). If None, batch-item
        indices themselves are used as folder names.
    i8_encoding_mode : int, str or dict
        whether (and how) skyscraper-components should be cast to int8.
        None disables the cast. Saves disk space and speeds up loading at
        the cost of precision (components are originally float32).
        Can be int 0/1/2 or str/None: None, 'linear', 'quantization'
        (0 = None, 1 = 'linear', 2 = 'quantization'), or a component-wise
        dict of modes, e.g. {'images': 'linear', 'masks': 0}.

    Examples
    --------
    >>> ind = ['1ae34g90', '3hf82s76']
    >>> batch = CTImagesBatch(ind)
    >>> batch.load(...)
    >>> batch.dump(dst='./data/blosc_preprocessed')

    The command above creates, per item, files like:

    - ./data/blosc_preprocessed/1ae34g90/images/data.blk
    - ./data/blosc_preprocessed/1ae34g90/images/data.shape
    - ./data/blosc_preprocessed/1ae34g90/spacing/data.pkl
    - ./data/blosc_preprocessed/1ae34g90/origin/data.pkl
    """
    # if components-arg is not supplied, dump all components
    components = self.components if components is None else components
    if fmt != 'blosc':
        raise NotImplementedError('Dump to {} is not implemented yet'.format(fmt))
    # make sure that components is iterable
    components = np.asarray(components).reshape(-1)
    data_items = dict()
    for component in components:
        # extension depends on component kind: small meta goes to pickle,
        # array data goes to blosc
        ext = 'pkl' if component in ['spacing', 'origin'] else 'blk'
        # fetch the item's slice of the component and register it for dump
        comp_pos = self.get_pos(None, component, ix)
        data_items[component] = [getattr(self, component)[comp_pos], ext]
    # item-specific folder
    item_subdir = ix if index_to_name is None else index_to_name(ix)
    item_dir = os.path.join(dst, item_subdir)
    return await dump_data(data_items, item_dir, i8_encoding_mode)
def get_pos(self, data, component, index):
    """ Return the position of an item for a given index in data
    or in self.`component`.

    Fetch correct position inside batch for an item: looks in `data`,
    if provided, or in `component` of self otherwise.

    Parameters
    ----------
    data : None or ndarray
        data from which subsetting is done.
        If None, retrieve position from `component` of batch;
        if ndarray, return index as-is.
    component : str
        name of a component, e.g. 'images'.
        If component provided, data should be None.
    index : str or int
        index of the item to look for; may be a key from the dataset (str)
        or an index inside the batch (int).

    Returns
    -------
    int
        Position of item

    Notes
    -----
    This is an overload of get_pos from base Batch-class,
    see corresponding docstring for detailed explanation.
    """
    if data is not None:
        return index
    ind_pos = self._get_verified_pos(index)
    if component == 'images':
        # images are stacked into a skyscraper -> slice of z-rows
        return slice(self.lower_bounds[ind_pos], self.upper_bounds[ind_pos])
    # other components are per-item rows
    return slice(ind_pos, ind_pos + 1)
def _get_verified_pos(self, index):
    """ Get position of patient in batch.

    Whatever index is passed in this method, it returns
    the corresponding position inside the batch.

    Parameters
    ----------
    index : str or int
        Either a position of a patient in self.images or a key from
        self.index. If int, index is already the patient's position in
        the Batch; if str, it's treated as a key and converted to a
        position in batch.

    Returns
    -------
    int
        patient's position inside batch

    Raises
    ------
    IndexError
        if an integer index is out of range.
    """
    # also accept numpy integer scalars, which commonly appear when
    # indices come out of ndarray computations (plain isinstance(.., int)
    # rejects them)
    if isinstance(index, (int, np.integer)):
        if 0 <= index < len(self):
            pos = index
        else:
            raise IndexError("Index is out of range")
    else:
        pos = self.index.get_pos(index)
    return pos
@property
def images_shape(self):
    """ Get shapes for all 3d scans in CTImagesBatch.

    Returns
    -------
    ndarray
        shapes of data for each patient, ndarray(patient_pos, 3)
    """
    # NOTE: np.int alias was removed in numpy >= 1.24, use builtin int
    shapes = np.zeros((len(self), 3), dtype=int)
    shapes[:, 0] = self.upper_bounds - self.lower_bounds
    shapes[:, 1], shapes[:, 2] = self.slice_shape
    return shapes
@property
def lower_bounds(self):
    """ Get lower bounds of patients' data in CTImagesBatch.

    Returns
    -------
    ndarray
        ndarray(n_patients,) with lower bounds of patients' data
        along the z-axis.
    """
    # all bounds but the last one are floors of items
    return self._bounds[:-1]
@property
def upper_bounds(self):
    """ Get upper bounds of patients' data in CTImagesBatch.

    Returns
    -------
    ndarray
        ndarray(n_patients,) with upper bounds of patients' data
        along the z-axis.
    """
    # all bounds but the first one are ceilings of items
    return self._bounds[1:]
@property
def slice_shape(self):
    """ Get shape of a slice in the yx-plane.

    Returns
    -------
    ndarray
        ndarray([y_dim, x_dim], dtype=np.int) with shape of scan slice.
    """
    # skyscraper is (total_z, y, x) -> drop the z-dimension
    return np.asarray(self.images.shape[1:])
def rescale(self, new_shape):
    """ Recompute spacing values for patients' data after a resize.

    Parameters
    ----------
    new_shape : ndarray(dtype=np.int)
        shape of a patient's 3d array after resize, in format
        np.array([z_dim, y_dim, x_dim], dtype=np.int).

    Returns
    -------
    ndarray
        ndarray(n_patients, 3) with spacing values for each
        patient along z, y, x axes.
    """
    # physical extent (spacing * shape) stays fixed, so spacing scales
    # inversely with the shape change
    old_extent = self.spacing * self.images_shape
    return old_extent / new_shape
def _reraise_worker_exceptions(self, worker_outputs):
    """ Reraise exceptions coming from worker-functions, if there are any.

    Parameters
    ----------
    worker_outputs : list
        list of workers' results

    Raises
    ------
    RuntimeError
        if any worker failed; carries the collected errors.
    """
    if not any_action_failed(worker_outputs):
        return
    all_errors = self.get_errors(worker_outputs)
    raise RuntimeError("Failed parallelizing. Some of the workers failed with following errors: ", all_errors)
def _post_default(self, list_of_arrs, update=True, new_batch=False, **kwargs):
    """ Gather outputs of different workers, update the `images` component.

    Parameters
    ----------
    list_of_arrs : list
        list of ndarrays to be concatenated and put in batch.images.
    update : bool
        if False, nothing is performed.
    new_batch : bool
        if False, data is put into self;
        if True, a new batch is created, loaded and returned.

    Returns
    -------
    batch
        new batch, empty batch or self-batch.

    Notes
    -----
    Output of each worker should correspond to an individual patient.
    """
    self._reraise_worker_exceptions(list_of_arrs)
    if not update:
        return self
    # stack workers' arrays into a skyscraper; recompute bounds
    new_data = np.concatenate(list_of_arrs, axis=0)
    new_bounds = np.cumsum(np.array([len(a) for a in [[]] + list_of_arrs]))
    params = dict(images=new_data, bounds=new_bounds,
                  origin=self.origin, spacing=self.spacing)
    if new_batch:
        batch = type(self)(self.index)
        batch.load(fmt='ndarray', **params)
        return batch
    self._init_data(**params)
    return self
def _post_components(self, list_of_dicts, **kwargs):
    """ Gather outputs of different workers, update many components.

    Parameters
    ----------
    list_of_dicts : list
        list of dicts {`component_name`: what_to_place_in_component}

    Returns
    -------
    self
        changes self's components
    """
    self._reraise_worker_exceptions(list_of_dicts)
    # images require a bounds-recomputation, handle them first
    if 'images' in list_of_dicts[0]:
        images_list = [worker_res['images'] for worker_res in list_of_dicts]
        new_bounds = np.cumsum(np.array([len(a) for a in [[]] + images_list]))
        self._init_data(images=np.concatenate(images_list, axis=0),
                        bounds=new_bounds,
                        origin=self.origin, spacing=self.spacing)
    # remaining components are simply concatenated worker-wise
    for component in list_of_dicts[0]:
        if component == 'images':
            continue
        parts = [worker_res[component] for worker_res in list_of_dicts]
        setattr(self, component, np.concatenate(parts, axis=0))
    return self
def _init_images(self, **kwargs):
    """ Fetch args for loading `images` using inbatch_parallel.

    Args-fetcher for parallelization using inbatch_parallel.

    Returns
    -------
    list
        list of patients' 3d arrays.
    """
    return [self.get(pid, 'images') for pid in self.indices]
def _init_rebuild(self, **kwargs):
    """ Fetch args for `images` rebuilding using inbatch_parallel.

    Args-fetcher for parallelization using inbatch_parallel.

    Parameters
    ----------
    **kwargs
        shape : tuple, list or ndarray of int
            (z,y,x)-shape of every image in the image component after
            the action is performed.
        spacing : tuple, list or ndarray of float
            (z,y,x)-spacing for each image. If supplied, assume that
            unify_spacing is performed.

    Returns
    -------
    list
        list of arg-dicts for different workers
    """
    if 'shape' in kwargs:
        num_slices, y, x = kwargs['shape']
        new_bounds = num_slices * np.arange(len(self) + 1)
        new_data = np.zeros((num_slices * len(self), y, x))
    else:
        new_bounds = self._bounds
        new_data = np.zeros_like(self.images)
    # for unify_spacing: target shapes are loop-invariant, compute once
    if 'spacing' in kwargs:
        shape_after_resize = (self.images_shape * self.spacing
                              / np.asarray(kwargs['spacing']))
        # NOTE: np.int alias was removed in numpy >= 1.24, use builtin int
        shape_after_resize = np.rint(shape_after_resize).astype(int)
    all_args = []
    for i in range(len(self)):
        out_patient = new_data[new_bounds[i]: new_bounds[i + 1], :, :]
        item_args = {'patient': self.get(i, 'images'),
                     'out_patient': out_patient,
                     'res': new_data}
        # for unify_spacing
        if 'spacing' in kwargs:
            item_args['factor'] = self.spacing[i, :] / np.array(kwargs['spacing'])
            item_args['shape_resize'] = shape_after_resize[i, :]
        all_args += [item_args]
    return all_args
def _init_dump(self, **kwargs):
    """ Init function for dump.

    Checks that all components to be dumped are non-None. If some are None,
    prints a warning and makes sure that nothing is dumped.

    Parameters
    ----------
    **kwargs:
        components : tuple, list, ndarray of strings or str
            components that we need to dump

    Returns
    -------
    list
        indices to dump, or an empty list when some component is empty.
    """
    components = kwargs.get('components', self.components)
    # make sure that components is iterable
    components = np.asarray(components).reshape(-1)
    _empty = [c for c in components if not self._if_component_filled(c)]
    if _empty:
        # dumping a partially-filled batch would write inconsistent data
        logger.warning('Components %r are empty. Nothing is dumped!', _empty)
        return []
    return self.indices
def _post_rebuild(self, all_outputs, new_batch=False, **kwargs):
    """ Gather outputs of different workers for actions, rebuild the
    `images` component.

    Parameters
    ----------
    all_outputs : list
        list of outputs. Each item is given by tuple
    new_batch : bool
        if True, returns new batch with data aggregated
        from all_outputs; if False, changes self.
    **kwargs
        shape : list, tuple or ndarray of int
            (z,y,x)-shape of every image in image component after action is performed.
        spacing : tuple, list or ndarray of float
            (z,y,x)-spacing for each image. If supplied, assume that
            unify_spacing is performed.
    """
    self._reraise_worker_exceptions(all_outputs)
    # prepend a dummy zero-shape entry so cumsum yields bound-floors
    new_bounds = np.cumsum([patient_shape[0] for _, patient_shape
                            in [[0, (0, )]] + all_outputs])
    # each worker returns the same ref to the whole res array
    new_data, _ = all_outputs[0]
    # recalculate new_attrs of a batch:
    # if shape is supplied, assume post is for resize or unify_spacing
    new_spacing = self.rescale(kwargs['shape']) if 'shape' in kwargs else self.spacing
    # if spacing is supplied, assume post is for unify_spacing
    if 'spacing' in kwargs:
        # recalculate origin, spacing
        shape_after_resize = np.rint(self.images_shape * self.spacing
                                     / np.asarray(kwargs['spacing']))
        overshoot = shape_after_resize - np.asarray(kwargs['shape'])
        new_spacing = self.rescale(new_shape=shape_after_resize)
        new_origin = self.origin + new_spacing * (overshoot // 2)
    else:
        new_origin = self.origin
    # build/update batch with new data and attrs
    params = dict(images=new_data, bounds=new_bounds,
                  origin=new_origin, spacing=new_spacing)
    if new_batch:
        batch_res = type(self)(self.index)
        batch_res.load(fmt='ndarray', **params)
        return batch_res
    self._init_data(**params)
    return self
@action
@inbatch_parallel(init='_init_rebuild', post='_post_rebuild', target='threads')
def resize(self, patient, out_patient, res, shape=(128, 256, 256), method='pil-simd',
           axes_pairs=None, resample=None, order=3, *args, **kwargs):
    """ Resize (change shape of) each CT-scan in the batch.

    When called from a batch, changes this batch.

    Parameters
    ----------
    shape : tuple, list or ndarray of int
        (z,y,x)-shape that should be AFTER resize.
        Note, that ct-scan dim_ordering also should be `z,y,x`.
    method : str
        interpolation package to be used. Either 'pil-simd' or 'scipy'.
        Pil-simd ensures better quality and speed on configurations
        with average number of cores. On the contrary, scipy is better scaled and
        can show better performance on systems with large number of cores.
    axes_pairs : None or list/tuple of tuples with pairs
        pairs of axes that will be used for performing pil-simd resize,
        as this resize is made in 2d. Min number of pairs to use is 1,
        at max there can be 6 pairs. If None, set to ((0, 1), (1, 2)).
        The more pairs one uses, the more precise is the result
        (and computation takes more time).
    resample : filter of pil-simd resize. By default set to bilinear. Can be any of filters
        supported by PIL.Image.
    order : the order of scipy-interpolation (<= 5)
        large value improves precision, but slows down the computation.

    Raises
    ------
    ValueError
        if `method` is neither 'scipy' nor 'pil-simd'.

    Examples
    --------
    >>> shape = (128, 256, 256)
    >>> batch = batch.resize(shape=shape, order=2, method='scipy')
    >>> batch = batch.resize(shape=shape, resample=PIL.Image.BILINEAR)
    """
    if method == 'scipy':
        args_resize = dict(patient=patient, out_patient=out_patient, res=res, order=order)
        return resize_scipy(**args_resize)
    elif method == 'pil-simd':
        args_resize = dict(input_array=patient, output_array=out_patient,
                           res=res, axes_pairs=axes_pairs, resample=resample)
        return resize_pil(**args_resize)
    # previously an unknown method fell through and silently returned None,
    # which corrupted the rebuilt images component
    raise ValueError("method must be 'scipy' or 'pil-simd', got {!r}".format(method))
@action
@inbatch_parallel(init='_init_rebuild', post='_post_rebuild', target='threads')
def unify_spacing(self, patient, out_patient, res, factor,
                  shape_resize, spacing=(1, 1, 1), shape=(128, 256, 256),
                  method='pil-simd', order=3, padding='edge', axes_pairs=None,
                  resample=None, *args, **kwargs):
    """ Unify spacing of all patients.

    Resize all patients to meet `spacing`, then crop/pad resized array to meet `shape`.

    Parameters
    ----------
    spacing : tuple, list or ndarray of float
        (z,y,x)-spacing after resize.
        Should be passed as key-argument.
    shape : tuple, list or ndarray of int
        (z,y,x)-shape after crop/pad.
        Should be passed as key-argument.
    method : str
        interpolation method ('pil-simd' or 'scipy').
        Should be passed as key-argument.
        See CTImagesBatch.resize for more information.
    order : None or int
        order of scipy-interpolation (<=5), if used.
        Should be passed as key-argument.
    padding : str
        mode of padding, any supported by np.pad.
        Should be passed as key-argument.
    axes_pairs : tuple, list of tuples with pairs
        pairs of axes that will be used consequentially
        for performing pil-simd resize.
        Should be passed as key-argument.
    resample : None or str
        filter of pil-simd resize.
        Should be passed as key-argument
    patient : str
        index of patient, that worker is handling.
        Note: this argument is passed by inbatch_parallel
    out_patient : ndarray
        result of individual worker after action.
        Note: this argument is passed by inbatch_parallel
    res : ndarray
        New `images` to replace data inside `images` component.
        Note: this argument is passed by inbatch_parallel
    factor : tuple
        (float), factor to make resize by.
        Note: this argument is passed by inbatch_parallel
    shape_resize : tuple
        It is possible to provide `shape_resize` argument (shape after resize)
        instead of spacing. Then array with `shape_resize`
        will be cropped/padded for shape to = `shape` arg.
        Note that this argument is passed by inbatch_parallel

    Raises
    ------
    ValueError
        if `method` is neither 'scipy' nor 'pil-simd'.

    Notes
    -----
    see CTImagesBatch.resize for more info about methods' params.

    Examples
    --------
    >>> shape = (128, 256, 256)
    >>> batch = batch.unify_spacing(shape=shape, spacing=(1.0, 1.0, 1.0),
                                    order=2, method='scipy', padding='reflect')
    >>> batch = batch.unify_spacing(shape=shape, spacing=(1.0, 1.0, 1.0),
                                    resample=PIL.Image.BILINEAR)
    """
    if method == 'scipy':
        return resize_scipy(patient=patient, out_patient=out_patient,
                            res=res, order=order, factor=factor, padding=padding)
    if method == 'pil-simd':
        return resize_pil(input_array=patient, output_array=out_patient,
                          res=res, axes_pairs=axes_pairs, resample=resample,
                          shape_resize=shape_resize, padding=padding)
    # fail loudly instead of silently returning None into _post_rebuild,
    # consistent with CTImagesBatch.resize
    raise ValueError("method must be 'scipy' or 'pil-simd', got {!r}".format(method))
@action
@inbatch_parallel(init='indices', post='_post_default', update=False, target='threads')
def rotate(self, index, angle, components='images', axes=(1, 2), random=True, **kwargs):
    """ Rotate 3D images in batch on specific angle in plane.

    Parameters
    ----------
    angle : float
        degree of rotation.
    components : tuple, list, ndarray of strings or str
        name(s) of components to rotate each item in it.
    axes : tuple, list or ndarray of int
        (int, int), plane of rotation specified by two axes (zyx-ordering).
    random : bool
        if True, then degree specifies maximum angle of rotation.

    Returns
    -------
    ndarray
        ndarray of 3D rotated image.

    Notes
    -----
    zero padding automatically added after rotation.
    Use this action in the end of pipelines for purposes of augmentation.
    E.g., after :func:`~radio.preprocessing.ct_masked_batch.CTImagesMaskedBatch.sample_nodules`

    Examples
    --------
    Rotate images on 90 degrees:

    >>> batch = batch.rotate(angle=90, axes=(1, 2), random=False)

    Random rotation with maximum angle:

    >>> batch = batch.rotate(angle=30, axes=(1, 2))
    """
    # draw one angle per item, shared across all requested components
    chosen_angle = angle * np.random.rand() if random else angle
    # NOTE(review): rotate_3D's return value is discarded and update=False,
    # so this relies on rotate_3D mutating the arrays in place — confirm.
    for comp_name in np.asarray(components).reshape(-1):
        rotate_3D(self.get(index, comp_name), chosen_angle, axes)
@inbatch_parallel(init='_init_images', post='_post_default', target='threads', new_batch=True)
def _make_xip(self, image, depth, stride=2, mode='max',
              projection='axial', padding='reflect', *args, **kwargs):
    """ Make intensity projection (maximum, minimum, mean or median) of one scan.

    Worker behind `make_xip`: parallelized over batch items by
    `inbatch_parallel`, with results gathered into a new batch
    (new_batch=True). Notice that axis is chosen according to
    projection argument.

    Parameters
    ----------
    image : ndarray
        3d-array with one item's scan; supplied by inbatch_parallel.
    depth : int
        number of slices over which xip operation is performed.
    stride : int
        stride-step along projection dimension.
    mode : str
        Possible values are 'max', 'min', 'mean' or 'median'.
    projection : str
        Possible values: 'axial', 'coronal', 'sagital'.
        In case of 'coronal' and 'sagital' projections tensor
        will be transposed from [z,y,x] to [x,z,y] and [y,z,x].
    padding : str
        mode of padding that will be passed in numpy.padding function.

    Returns
    -------
    ndarray
        the projected array produced by the numba-compiled kernel.
    """
    # all heavy lifting is delegated to the numba-compiled helper
    return make_xip_numba(image, depth, stride, mode, projection, padding)
@action
def make_xip(self, depth, stride=1, mode='max', projection='axial', padding='reflect', **kwargs):
    """ Make intensity projection (maximum, minimum, mean or median).

    Notice that axis is chosen according to projection argument.

    Parameters
    ----------
    depth : int
        number of slices over which xip operation is performed.
    stride : int
        stride-step along projection dimension.
    mode : str
        Possible values are 'max', 'min', 'mean' or 'median'.
    projection : str
        Possible values: 'axial', 'coronal', 'sagital'.
        In case of 'coronal' and 'sagital' projections tensor
        will be transposed from [z,y,x] to [x,z,y] and [y,z,x].
    padding : str
        mode of padding that will be passed in numpy.padding function.
    """
    # delegate the per-item work to the parallelized worker
    xipped = self._make_xip(depth=depth, stride=stride, mode=mode,  # pylint: disable=no-value-for-parameter
                            projection=projection, padding=padding)
    # projection changes item shapes, so spacing must be rescaled to match
    xipped.spacing = self.rescale(xipped.images_shape)
    return xipped
@inbatch_parallel(init='_init_rebuild', post='_post_rebuild', target='threads', new_batch=True)
def calc_lung_mask(self, patient, out_patient, res, erosion_radius, **kwargs): # pylint: disable=unused-argument, no-self-use
    """ Return a mask for lungs

    Worker parallelized over batch items by `inbatch_parallel`;
    per-item results are assembled into a new batch (new_batch=True).

    Parameters
    ----------
    patient : str
        index of patient that the worker is handling.
        Note: this argument is passed by inbatch_parallel
    out_patient : ndarray
        result of individual worker after action.
        Note: this argument is passed by inbatch_parallel
    res : ndarray
        array the assembled masks are placed into.
        Note: this argument is passed by inbatch_parallel
    erosion_radius : int
        radius of erosion to be performed.

    Returns
    -------
    per-item result of the numba-compiled helper,
    consumed by the `_post_rebuild` gatherer.
    """
    # the actual lung segmentation lives in a numba-compiled helper
    return calc_lung_mask_numba(patient, out_patient, res, erosion_radius)
@action
def segment(self, erosion_radius=2, **kwargs):
    """ Segment lungs' content from 3D array.

    Parameters
    ---------
    erosion_radius : int
        radius of erosion to be performed.

    Returns
    -------
    batch

    Notes
    -----
    Sets HU of every pixel outside lungs to DARK_HU = -2000.

    Examples
    --------
    >>> batch = batch.segment(erosion_radius=4, num_threads=20)
    """
    # build the lungs mask with the requested erosion
    lungs = self.calc_lung_mask(erosion_radius=erosion_radius, **kwargs).images  # pylint: disable=no-value-for-parameter
    # keep only voxels inside the lungs...
    self.images *= lungs
    # ...and push every voxel outside the lungs down to DARK_HU
    background = 1 - lungs
    background *= DARK_HU
    self.images += background
    return self
@action
def central_crop(self, crop_size, **kwargs):
    """ Make crop of crop_size from center of images.

    Parameters
    ----------
    crop_size : tuple, list or ndarray of int
        (z,y,x)-shape of crop.

    Returns
    -------
    batch
    """
    crop_size = np.asarray(crop_size).reshape(-1)
    halfsize = np.rint(crop_size / 2)
    # fetch each item's scan once and validate before mutating anything
    scans = [self.get(num, 'images') for num in range(len(self))]
    if any(np.any(np.asarray(scan.shape) < crop_size) for scan in scans):
        raise ValueError("Crop size must be smaller than size of inner 3D images")
    crops = [make_central_crop(scan, crop_size) for scan in scans]
    # after cropping every item has crop_size[0] slices
    self._bounds = np.cumsum([0] + [crop_size[0]] * len(self))
    self.images = np.concatenate(crops, axis=0)
    # shift world origin to the corner of the central crop
    self.origin = self.origin + self.spacing * halfsize
    return self
def get_patches(self, patch_shape, stride, padding='edge', data_attr='images'):
    """ Extract patches of patch_shape with specified stride.

    Parameters
    ----------
    patch_shape : tuple, list or ndarray of int
        (z,y,x)-shape of a single patch.
    stride : tuple, list or ndarray of int
        (z,y,x)-stride to slide over each patient's data.
    padding : str
        padding-type (see doc of np.pad for available types).
    data_attr : str
        component to get data from.

    Returns
    -------
    ndarray
        4d-ndaray of patches; first dimension enumerates patches

    Notes
    -----
    Shape of all patients data is needed to be the same at this step,
    resize/unify_spacing is required before.
    """
    patch_shape = np.asarray(patch_shape).reshape(-1)
    stride = np.asarray(stride).reshape(-1)
    # all items share one shape here, so the first item's shape is enough
    item_shape = self.images_shape[0]
    stacked = np.reshape(getattr(self, data_attr), (-1, *item_shape))
    # pad each item if patches do not tile the scan exactly
    pad_width = calc_padding_size(item_shape, patch_shape, stride)
    padded = stacked if pad_width is None else np.pad(stacked, pad_width, mode=padding)
    # number of patch positions along each spatial axis
    sections = (np.asarray(padded.shape[1:]) - patch_shape) // stride + 1
    all_patches = np.zeros(shape=(len(self), np.prod(sections), *patch_shape))
    # dummy per-item array — presumably required by the numba helper's
    # signature for parallelization; confirm against get_patches_numba
    dummy = np.zeros(len(self))
    get_patches_numba(padded, patch_shape, stride, all_patches, dummy)
    # flatten (item, patch) dims into a single patch-enumerating axis
    return np.reshape(all_patches, (len(self) * np.prod(sections), *patch_shape))
def load_from_patches(self, patches, stride, scan_shape, data_attr='images'):
    """ Get skyscraper from 4d-array of patches, put it to `data_attr` component in batch.

    Let reconstruct original skyscraper from patches (if same arguments are passed)

    Parameters
    ----------
    patches : ndarray
        4d-array of patches, with dims: `(num_patches, z, y, x)`.
    scan_shape : tuple, list or ndarray of int
        (z,y,x)-shape of individual scan (should be same for all scans).
    stride : tuple, list or ndarray of int
        (z,y,x)-stride step used for gathering data from patches.
    data_attr : str
        batch component name to store new data.

    Notes
    -----
    If stride != patch.shape(), averaging of overlapped regions is used.
    `scan_shape`, patches.shape(), `stride` are used to infer the number of items
    in new skyscraper. If patches were padded, padding is removed for skyscraper.
    """
    scan_shape = np.asarray(scan_shape).reshape(-1)
    stride = np.asarray(stride).reshape(-1)
    # spatial shape of one patch, taken from the input tensor itself
    patch_shape = np.asarray(patches.shape[1:]).reshape(-1)
    # infer what padding was applied to scans when extracting patches
    pad_width = calc_padding_size(scan_shape, patch_shape, stride)
    # if padding is non-zero, adjust the shape of scan
    if pad_width is not None:
        # pad_width[0] covers the item axis; only spatial axes widen the scan
        shape_delta = np.asarray([before + after for before, after in pad_width[1:]])
    else:
        shape_delta = np.zeros(3).astype('int')
    scan_shape_adj = scan_shape + shape_delta
    # init 4d tensor and put assembled scans into it
    data_4d = np.zeros((len(self), *scan_shape_adj))
    # regroup flat patch axis into (item, patches-per-item)
    patches = np.reshape(patches, (len(self), -1, *patch_shape))
    # dummy per-item array — presumably a placeholder required by the
    # numba helper's signature; confirm against assemble_patches
    fake = np.zeros(len(self))
    assemble_patches(patches, stride, data_4d, fake)
    # crop (perform anti-padding) if necessary
    if pad_width is not None:
        data_shape = data_4d.shape
        # strip the (before, after) padding from each spatial axis
        slc_z = slice(pad_width[1][0], data_shape[1] - pad_width[1][1])
        slc_y = slice(pad_width[2][0], data_shape[2] - pad_width[2][1])
        slc_x = slice(pad_width[3][0], data_shape[3] - pad_width[3][1])
        data_4d = data_4d[:, slc_z, slc_y, slc_x]
    # reshape 4d-data to skyscraper form and put it into needed attr
    data_4d = data_4d.reshape((len(self) * scan_shape[0], *scan_shape[1:]))
    setattr(self, data_attr, data_4d)
@action
def normalize_hu(self, min_hu=-1000, max_hu=400):
    """ Normalize HU-densities to interval [0, 255].

    Trim HU that are outside range [min_hu, max_hu], then scale to [0, 255].

    Parameters
    ----------
    min_hu : int
        minimum value for hu that will be used as trimming threshold.
    max_hu : int
        maximum value for hu that will be used as trimming threshold.

    Returns
    -------
    batch

    Raises
    ------
    ValueError
        if `max_hu` is not strictly greater than `min_hu`.

    Examples
    --------
    >>> batch = batch.normalize_hu(min_hu=-1300, max_hu=600)
    """
    # guard against a degenerate window that would divide by zero
    if max_hu <= min_hu:
        raise ValueError("max_hu must be greater than min_hu")
    # shift/scale to [0, 1], clip outliers, then stretch to [0, 255]
    self.images = np.clip((self.images - min_hu) / (max_hu - min_hu), 0., 1.) * 255
    return self
@action
@inbatch_parallel(init='_init_rebuild', post='_post_rebuild', target='threads')
def flip(self, patient, out_patient, res): # pylint: disable=no-self-use
    """ Invert the order of slices for each patient

    Worker parallelized over batch items by `inbatch_parallel`;
    `patient`, `out_patient` and `res` are supplied by the decorator
    (see `unify_spacing` for their meaning).

    Returns
    -------
    batch

    Examples
    --------
    >>> batch = batch.flip()
    """
    # the slice reversal itself is implemented in a numba-compiled helper
    return flip_patient_numba(patient, out_patient, res)
def get_axial_slice(self, person_number, slice_height):
    """ Get axial slice (e.g., for plots)

    Parameters
    ----------
    person_number : str or int
        Either the position (int) of a person in the batch
        or the patient_id (str).
    slice_height : float
        relative slice position scaled to [0, 1];
        e.g. 0.7 picks slice number
        int(0.7 * number of slices for person).

    Returns
    -------
    ndarray (view)

    Examples
    --------
    Here self.index[5] is usually something like 'a1de03fz29kf6h2'

    >>> patch = batch.get_axial_slice(5, 0.6)
    >>> patch = batch.get_axial_slice(self.index[5], 0.6)
    """
    scan = self.get(person_number, 'images')
    # translate the relative height into an absolute slice index
    slice_pos = int(slice_height * scan.shape[0])
    return scan[slice_pos, :, :]