Source code for sccross.models.data

r"""
Data handling utilities
"""

import functools
import multiprocessing
import operator
import os
import queue
import signal
from math import ceil
from typing import Any, List, Mapping, Optional
import numpy as np
import scipy.sparse
import torch
from anndata._core.sparse_dataset import SparseDataset
from ..utils import config, get_rs, logged, processes, Array, RandomState



#---------------------------------- Datasets -----------------------------------

@logged
class Dataset(torch.utils.data.Dataset):

    r"""
    Abstract dataset interface extending that of :class:`torch.utils.data.Dataset`

    Parameters
    ----------
    getitem_size
        Unitary fetch size for each __getitem__ call

    Note
    ----
    Subclasses implement :meth:`propose_shuffle` (compute a shuffled result
    from a seed) and :meth:`accept_shuffle` (install a shuffled result).
    :meth:`shuffle` either runs both in the current process, or, once
    :meth:`prepare_shuffle` has started background workers, accepts results
    that the workers proposed ahead of time.
    """

    def __init__(self, getitem_size: int = 1) -> None:
        super().__init__()
        self.getitem_size = getitem_size
        # Seed consumed by the next ``shuffle`` call (set by ``prepare_shuffle``)
        self.shuffle_seed: Optional[int] = None
        # Main process -> workers: seeds to propose for (``None`` asks a worker to exit)
        self.seed_queue: Optional[multiprocessing.Queue] = None
        # Workers -> main process: ``(seed, proposal)`` or ``(None, pid)`` on exit
        self.propose_queue: Optional[multiprocessing.Queue] = None
        # Proposals that arrived before they were needed, keyed by seed
        self.propose_cache: Mapping[int, Any] = {}

    @property
    def has_workers(self) -> bool:
        r"""
        Whether background shuffling workers have been registered

        Raises
        ------
        RuntimeError
            If the worker registry and the two queues disagree about
            whether workers exist
        """
        worker_state = (
            bool(processes[id(self)]),
            self.seed_queue is not None,
            self.propose_queue is not None,
        )
        if len(set(worker_state)) != 1:
            raise RuntimeError("Background shuffling seems broken!")
        return worker_state[0]

    def prepare_shuffle(self, num_workers: int = 1, random_seed: int = 0) -> None:
        r"""
        Prepare dataset for custom shuffling

        Parameters
        ----------
        num_workers
            Number of background workers for data shuffling
            (0 means shuffling happens in the calling process)
        random_seed
            Initial random seed (will increase by 1 with every shuffle call)
        """
        if self.has_workers:
            self.clean()
        registry = processes[id(self)]
        self.shuffle_seed = random_seed
        if not num_workers:
            return
        self.seed_queue = multiprocessing.Queue()
        self.propose_queue = multiprocessing.Queue()
        for offset in range(num_workers):
            worker = multiprocessing.Process(target=self.shuffle_worker)
            worker.start()
            self.logger.debug("Started background process: %d", worker.pid)
            registry[worker.pid] = worker
            # Prime each worker with one of the first ``num_workers`` seeds
            self.seed_queue.put(self.shuffle_seed + offset)

    def shuffle(self) -> None:
        r"""
        Custom shuffling
        """
        if not self.has_workers:
            self.accept_shuffle(self.propose_shuffle(self.shuffle_seed))
        else:
            registry = processes[id(self)]
            # Look ahead: request the seed that keeps every worker busy
            self.seed_queue.put(self.shuffle_seed + len(registry))
            # Proposals may arrive out of order; cache them until ours shows up
            while self.shuffle_seed not in self.propose_cache:
                proposed_seed, proposal = self.propose_queue.get()
                self.propose_cache[proposed_seed] = proposal
            self.accept_shuffle(self.propose_cache.pop(self.shuffle_seed))
        self.shuffle_seed += 1

    def shuffle_worker(self) -> None:
        r"""
        Background shuffle worker
        """
        # Leave Ctrl-C handling to the main process
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        for seed in iter(self.seed_queue.get, None):
            self.propose_queue.put((seed, self.propose_shuffle(seed)))
        # A ``None`` seed ends the loop; report our pid so ``clean`` can join us
        self.propose_queue.put((None, os.getpid()))

    def propose_shuffle(self, seed: int) -> Any:
        r"""
        Propose shuffling using a given random seed

        Parameters
        ----------
        seed
            Random seed

        Returns
        -------
        shuffled
            Shuffled result
        """
        raise NotImplementedError  # pragma: no cover

    def accept_shuffle(self, shuffled: Any) -> None:
        r"""
        Accept shuffling result

        Parameters
        ----------
        shuffled
            Shuffled result
        """
        raise NotImplementedError  # pragma: no cover

    def clean(self) -> None:
        r"""
        Clean up multi-process resources used in custom shuffling
        """
        registry = processes[id(self)]
        if not self.has_workers:
            return
        # One exit sentinel per registered worker
        for _ in range(len(registry)):
            self.seed_queue.put(None)
        self.propose_cache.clear()
        # Drain the propose queue until every worker has reported its exit,
        # discarding any proposals still in flight
        while registry:
            try:
                tag, payload = self.propose_queue.get(
                    timeout=config.FORCE_TERMINATE_WORKER_PATIENCE
                )
            except queue.Empty:
                break
            if tag is None:
                registry[payload].join()
                self.logger.debug("Joined background process: %d", payload)
                del registry[payload]
        # Workers that did not report back before the queue timed out
        for pid in list(registry):
            registry[pid].terminate()
            registry[pid].join()
            self.logger.debug("Terminated background process: %d", pid)
            del registry[pid]
        self.propose_queue = None
        self.seed_queue = None

    def __del__(self) -> None:
        self.clean()
@logged
class ArrayDataset(Dataset):

    r"""
    Array dataset for :class:`numpy.ndarray` and :class:`scipy.sparse.spmatrix`
    objects. Different arrays are considered as unpaired, and thus do not need
    to have identical sizes in the first dimension. Smaller arrays are recycled.
    Also, data fetched from this dataset are automatically densified.

    Parameters
    ----------
    *arrays
        An arbitrary number of data arrays (at least one, none of them empty)
    getitem_size
        Unitary fetch size for each __getitem__ call

    Raises
    ------
    ValueError
        If no array is given, or if any array is empty

    Note
    ----
    We keep using arrays because sparse tensors do not support slicing.
    Arrays are only converted to tensors after minibatch slicing.
    """

    def __init__(self, *arrays: Array, getitem_size: int = 1) -> None:
        super().__init__(getitem_size=getitem_size)
        self.sizes = None
        self.size = None
        self.view_idx = None
        self.shuffle_idx = None
        self.arrays = arrays

    @property
    def arrays(self) -> List[Array]:
        r"""
        Internal array objects
        """
        return self._arrays

    @arrays.setter
    def arrays(self, arrays: List[Array]) -> None:
        if not arrays:
            raise ValueError("At least one array is required!")
        self.sizes = [array.shape[0] for array in arrays]
        if min(self.sizes) == 0:
            raise ValueError("Empty array is not allowed!")
        # The dataset length is driven by the largest array
        self.size = max(self.sizes)
        self.view_idx = [np.arange(s) for s in self.sizes]
        # Unshuffled to begin with; this is the same list object as
        # ``view_idx`` until ``accept_shuffle`` installs a new one
        self.shuffle_idx = self.view_idx
        self._arrays = arrays

    def __len__(self) -> int:
        return ceil(self.size / self.getitem_size)

    def __getitem__(self, index: int) -> List[torch.Tensor]:
        # Positions covered by this fetch; arrays smaller than ``self.size``
        # wrap around through the modulo below (recycling)
        index = np.arange(
            index * self.getitem_size,
            min((index + 1) * self.getitem_size, self.size)
        )
        return [
            torch.as_tensor(a[self.shuffle_idx[i][np.mod(index, self.sizes[i])]].toarray())
            if scipy.sparse.issparse(a) or isinstance(a, SparseDataset)
            else torch.as_tensor(a[self.shuffle_idx[i][np.mod(index, self.sizes[i])]])
            for i, a in enumerate(self.arrays)
        ]

    def propose_shuffle(self, seed: int) -> List[np.ndarray]:
        r"""
        Propose a permutation of every array's view indices

        Parameters
        ----------
        seed
            Random seed

        Returns
        -------
        shuffled
            One permuted index array per data array
        """
        rs = get_rs(seed)
        return [rs.permutation(view_idx) for view_idx in self.view_idx]

    def accept_shuffle(self, shuffled: List[np.ndarray]) -> None:
        r"""
        Install permuted index arrays as produced by :meth:`propose_shuffle`

        Parameters
        ----------
        shuffled
            One permuted index array per data array
        """
        self.shuffle_idx = shuffled

    def random_split(
            self, fractions: List[float], random_state: RandomState = None
    ) -> List["ArrayDataset"]:
        r"""
        Randomly split the dataset into multiple subdatasets according to
        given fractions.

        Parameters
        ----------
        fractions
            Fraction of each split (all positive, summing to 1 up to
            floating-point tolerance)
        random_state
            Random state

        Returns
        -------
        subdatasets
            A list of split subdatasets

        Raises
        ------
        ValueError
            If any fraction is not positive, if the fractions do not sum to 1,
            or if a fraction is too small to receive any item of some array
        """
        if min(fractions) <= 0:
            raise ValueError("Fractions should be greater than 0!")
        # An exact ``!= 1`` test rejects valid inputs such as [0.7, 0.2, 0.1],
        # whose floating-point sum may come out as 0.9999999999999999
        if not np.isclose(sum(fractions), 1):
            raise ValueError("Fractions do not sum to 1!")
        rs = get_rs(random_state)
        cum_frac = np.cumsum(fractions)
        subdatasets = [
            ArrayDataset(
                *self.arrays, getitem_size=self.getitem_size
            ) for _ in fractions
        ]
        for j, view_idx in enumerate(self.view_idx):
            view_idx = rs.permutation(view_idx)
            split_pos = np.round(cum_frac * view_idx.size).astype(int)
            split_idx = np.split(view_idx, split_pos[:-1])  # Last pos produces an extra empty split
            for i, idx in enumerate(split_idx):
                if idx.size == 0:
                    # Would otherwise surface later as a modulo-by-zero /
                    # out-of-bounds index inside ``__getitem__``
                    raise ValueError(
                        f"Fraction {fractions[i]} yields an empty split "
                        f"for array {j}!"
                    )
                subdatasets[i].sizes[j] = len(idx)
                subdatasets[i].view_idx[j] = idx
                subdatasets[i].shuffle_idx[j] = idx
        # ``size`` (and hence ``__len__``) must reflect the split sizes, not
        # the parent's; otherwise small splits are recycled many times per epoch
        for subdataset in subdatasets:
            subdataset.size = max(subdataset.sizes)
        return subdatasets
#-------------------------------- Data loaders ---------------------------------
class DataLoader(torch.utils.data.DataLoader):

    r"""
    Custom data loader that manually shuffles the internal dataset before each
    round of iteration (see :class:`torch.utils.data.DataLoader` for usage)

    Note
    ----
    Any ``collate_fn`` passed in is replaced by :meth:`_collate`, which
    concatenates each field of the fetched items along dimension 0.
    """

    def __init__(self, dataset: Dataset, **kwargs) -> None:
        super().__init__(dataset, **kwargs)
        self.collate_fn = self._collate
        # Remember whether the caller asked for shuffling, so ``__iter__``
        # can also trigger the dataset's own shuffle
        self.shuffle = kwargs.get("shuffle", False)

    def __iter__(self) -> "DataLoader":
        if self.shuffle:
            # Re-draw the dataset's internal permutation before each round
            self.dataset.shuffle()
        return super().__iter__()

    @staticmethod
    def _collate(batch):
        # ``batch`` is a sequence of items, each a sequence of tensors;
        # concatenate the i-th tensor of every item together
        return tuple(torch.cat(parts, dim=0) for parts in zip(*batch))

    @staticmethod
    def _collate_graph(batch):
        # Items are ``(edge index, edge weight, edge sign)`` triples; edge
        # indices are concatenated along dim 1, the rest along dim 0
        indices, weights, signs = zip(*batch)
        return (
            torch.cat(indices, dim=1),
            torch.cat(weights, dim=0),
            torch.cat(signs, dim=0),
        )
class ParallelDataLoader:

    r"""
    Parallel data loader

    Parameters
    ----------
    *data_loaders
        An arbitrary number of data loaders
    cycle_flags
        Whether each data loader should be cycled in case they are of
        different lengths, by default none of them are cycled.

    Raises
    ------
    ValueError
        If ``cycle_flags`` does not have one entry per data loader

    Note
    ----
    Each step draws one batch from every loader, in order, and joins them
    with ``+`` (batches from :class:`DataLoader` are tuples, so this
    concatenates them). Iteration stops as soon as a non-cycled loader is
    exhausted; a cycled loader is restarted with a fresh ``iter()`` call.
    If every loader is cycled, iteration never ends on its own. With no
    loaders at all, iteration is empty.
    """

    def __init__(
            self, *data_loaders: DataLoader,
            cycle_flags: Optional[List[bool]] = None
    ) -> None:
        cycle_flags = cycle_flags or [False] * len(data_loaders)
        if len(cycle_flags) != len(data_loaders):
            raise ValueError("Invalid cycle flags!")
        self.cycle_flags = cycle_flags
        self.data_loaders = list(data_loaders)
        self.num_loaders = len(self.data_loaders)
        self.iterators = None

    def __iter__(self) -> "ParallelDataLoader":
        self.iterators = [iter(loader) for loader in self.data_loaders]
        return self

    def _next(self, i: int) -> List[torch.Tensor]:
        r"""
        Fetch the next batch of loader ``i``, restarting it if it is cycled
        """
        try:
            return next(self.iterators[i])
        except StopIteration:
            if not self.cycle_flags[i]:
                raise
            self.iterators[i] = iter(self.data_loaders[i])
            return next(self.iterators[i])

    def __next__(self) -> List[torch.Tensor]:
        if not self.num_loaders:
            # ``functools.reduce`` has nothing to combine; end iteration
            # instead of raising "reduce() of empty iterable"
            raise StopIteration
        return functools.reduce(
            operator.add, [self._next(i) for i in range(self.num_loaders)]
        )