Source code for probflow.data.data_generator

import multiprocessing as mp
from abc import abstractmethod

from probflow.utils.base import BaseDataGenerator


class DataGenerator(BaseDataGenerator):
    """Abstract base class for a data generator, which uses multiprocessing
    to load the data in parallel.

    TODO

    User needs to implement:

    * :meth:`~__init__`
    * :meth:`~n_samples`
    * :meth:`~batch_size`
    * :meth:`~get_batch`

    And can optionally implement:

    * :meth:`~on_epoch_start`
    * :meth:`~on_epoch_end`

    Parameters
    ----------
    num_workers : None or int
        If ``None`` (the default) every batch is loaded in the calling
        process.  Otherwise one process is created per batch;
        ``num_workers`` of them are started when iteration begins and one
        more is started each time a batch is requested.

    Notes
    -----
    With multiprocessing enabled, batches come back in the order the
    workers finish, which is not necessarily index order.

    This class calls ``len(self)``; ``__len__`` is not defined here and is
    presumably supplied by ``BaseDataGenerator`` -- TODO confirm.
    """

    def __init__(self, num_workers=None):
        self.num_workers = num_workers

    @abstractmethod
    def get_batch(self, index):
        """Generate one batch of data"""

    def __getitem__(self, index):
        """Generate one batch of data

        Without multiprocessing this is simply ``self.get_batch(index)``.
        With multiprocessing it starts the worker ``num_workers`` batches
        ahead of ``index`` (if there is one) and returns the next result
        available on the queue created by :meth:`__iter__`.

        Raises
        ------
        Exception
            Whatever exception ``get_batch`` raised inside a worker process
            is re-raised here.
        """

        # No multiprocessing
        if self.num_workers is None:
            return self.get_batch(index)

        # Start the next worker, keeping the pipeline ``num_workers`` deep
        next_worker = index + self.num_workers
        if next_worker < len(self):
            self._workers[next_worker].start()

        # Workers put ``(succeeded, payload)`` pairs on the queue; a failed
        # worker forwards its exception so we raise instead of blocking
        # forever on a result that will never arrive.
        succeeded, payload = self._queue.get()
        if not succeeded:
            raise payload
        return payload

    def __iter__(self):
        """Get an iterator over batches

        When multiprocessing is enabled this (re)creates the result queue
        and one not-yet-started process per batch, then starts the first
        ``num_workers`` of them.  Returns ``self``.
        """

        # Multiprocessing?
        if self.num_workers is not None:

            # Reap workers left over from an earlier iteration that was
            # abandoned part-way (e.g. the caller broke out of the loop).
            self._shutdown_workers()

            # NOTE: this target is a closure and so cannot be pickled; it
            # relies on a start method that does not pickle the target
            # (e.g. "fork").
            def get_data(index, queue):
                try:
                    queue.put((True, self.get_batch(index)))
                except Exception as err:
                    # Forward the failure to the parent process
                    queue.put((False, err))

            # Create the queue and worker processes
            self._queue = mp.Queue()
            self._workers = [
                mp.Process(target=get_data, args=(i, self._queue))
                for i in range(len(self))
            ]

            # Start the first num_workers workers
            for i in range(min(self.num_workers, len(self))):
                self._workers[i].start()

        # Keep track of what batch we're on
        self._batch = -1

        # Return iterator
        return self

    def __next__(self):
        """Get the next batch

        Raises
        ------
        StopIteration
            Once every batch of the epoch has been returned.
        """
        self._batch += 1
        if self._batch < len(self):
            return self[self._batch]

        # Every result has been read from the queue by now, so the workers
        # can be joined without risk of blocking on unflushed queue data.
        if self.num_workers is not None:
            self._shutdown_workers()
        raise StopIteration()

    def _shutdown_workers(self):
        """Terminate (if still running) and join all started workers."""
        for worker in getattr(self, "_workers", ()):
            if worker.pid is None:
                # Never started, so there is nothing to join
                continue
            if worker.is_alive():
                worker.terminate()
            worker.join()