Source code for pyblaze.nn.data.zip

class ZipDataLoader:
    """
    A data loader that zips together two underlying data loaders. The data
    loaders must sample the same batch size, and :code:`drop_last` must be
    set to `True` on data loaders that sample from a fixed-size dataset.

    Whenever one of the data loaders has a fixed size, this data loader
    defines a length. This length is given as the minimum of the two loaders'
    lengths, each divided by its respective count.

    A common use case for this class is the Wasserstein GAN, where the critic
    is trained for multiple iterations for each data batch.
    """

    def __init__(self, lhs_loader, rhs_loader, lhs_count=1, rhs_count=1):
        """
        Initializes a new data loader.

        Parameters
        ----------
        lhs_loader: torch.utils.data.DataLoader
            The data loader to sample from for the first item of each data tuple.
        rhs_loader: torch.utils.data.DataLoader
            The data loader to sample from for the second item of each data tuple.
        lhs_count: int
            The number of batches to sample for the first item of each data tuple.
        rhs_count: int
            The number of batches to sample for the second item of each data tuple.
        """
        if lhs_loader.batch_size != rhs_loader.batch_size:
            raise ValueError("Both given data loaders must have the same batch size.")
        self.lhs_loader = lhs_loader
        self.rhs_loader = rhs_loader
        self.lhs_count = lhs_count
        self.rhs_count = rhs_count
    def __len__(self):
        result = None
        try:
            result = len(self.lhs_loader) // self.lhs_count
        except TypeError:
            # The left-hand loader samples from an infinite dataset and does
            # not define a length.
            pass
        try:
            rhs_len = len(self.rhs_loader) // self.rhs_count
            result = rhs_len if result is None else min(result, rhs_len)
        except TypeError:
            # The right-hand loader does not define a length either.
            pass
        if result is None:
            raise TypeError("__len__ not implemented for instance of ZipDataLoader")
        return result

    def __iter__(self):
        lhs_it = iter(self.lhs_loader)
        rhs_it = iter(self.rhs_loader)
        while True:
            try:
                # Draw `lhs_count` batches from the first loader and
                # `rhs_count` batches from the second; stop as soon as either
                # iterator is exhausted.
                lhs_items = [next(lhs_it) for _ in range(self.lhs_count)]
                rhs_items = [next(rhs_it) for _ in range(self.rhs_count)]
            except StopIteration:
                return
            yield lhs_items, rhs_items
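
For reference, a minimal usage sketch (not part of the module): it pairs several batches from one loader with a single batch from another, in the spirit of the Wasserstein GAN use case named in the class docstring. The datasets, batch size, and counts below are illustrative assumptions.

# A minimal usage sketch; datasets, shapes, and counts are assumptions.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pyblaze.nn.data.zip import ZipDataLoader

real = TensorDataset(torch.randn(1000, 16))   # hypothetical "real" samples
noise = TensorDataset(torch.randn(1000, 8))   # hypothetical noise vectors

# drop_last=True is required on loaders over fixed-size datasets.
real_loader = DataLoader(real, batch_size=32, drop_last=True)
noise_loader = DataLoader(noise, batch_size=32, drop_last=True)

# Five "critic" batches for every one "generator" batch.
zipped = ZipDataLoader(real_loader, noise_loader, lhs_count=5, rhs_count=1)

# Both loaders yield 31 batches (1000 // 32 with drop_last), so
# len(zipped) == min(31 // 5, 31 // 1) == 6.
print(len(zipped))

for critic_batches, generator_batches in zipped:
    assert len(critic_batches) == 5 and len(generator_batches) == 1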