import inspect
import logging
import warnings
from typing import Any, Dict, Iterator, List, Sequence, Union

from torch.utils.data import (
    BatchSampler,
    DataLoader,
    Dataset,
    DistributedSampler,
    Sampler,
)

from oml.ddp.utils import WarningDDP, is_ddp
from oml.interfaces.samplers import IBatchSampler
# Every sampler kind accepted by the helpers below: PyTorch's `BatchSampler` / `Sampler` or `IBatchSampler`.
TAllSamplers = Union[BatchSampler, Sampler, IBatchSampler]
class _Sampler2Dataset(Dataset):
    """
    Expose the output of a sampler through the `Dataset` interface.

    Item ``i`` of this dataset is the ``i``-th element produced by the sampler (an index, or a list of indices
    in the case of a batch sampler).

    """

    def __init__(self, sampler: TAllSamplers):
        self.sampler = sampler
        # The sampler is consumed lazily in `__getitem__` (not here), because the seed may change between
        # the construction of this object and the first access to its items.
        self.sampler_read = None

    def __getitem__(self, item: int) -> Union[int, List[int]]:
        materialized = self.sampler_read
        if materialized is None:
            # First access: run the sampler once and cache everything it produced.
            materialized = self.sampler_read = list(self.sampler)  # type: ignore
        return materialized[item]  # type: ignore

    def __len__(self) -> int:
        # The length is taken from the sampler itself, so it is available before the sampler is consumed.
        return len(self.sampler)  # type: ignore
class DDPSamplerWrapper(DistributedSampler):
    """
    This is a wrapper to allow using custom sampler in DDP mode.

    Default `DistributedSampler` allows us to build a sampler for a dataset in DDP mode.
    Usually we can easily replace default `SequentialSampler` [when ``DataLoader(shuffle=False, ...)``] and
    `RandomSampler` [when ``DataLoader(shuffle=True, ...)``] with `DistributedSampler`. But for the custom sampler,
    we need an extra wrapper.

    Thus, this wrapper distributes indices produced by sampler among several devices for further usage.

    """

    def __init__(
        self, sampler: TAllSamplers, shuffle_samples_between_gpus: bool = True, pad_data_to_num_gpus: bool = True
    ):
        """
        Args:
            sampler: Sequential or batch sampler
            shuffle_samples_between_gpus: shuffle available indices before feeding them to GPU. Note, that shuffle
                inside GPU after the feeding will be used according to behavior of the sampler.
            pad_data_to_num_gpus: When using DDP we should manage behavior with the last batch, because each device
                should have the same amount of data. If the sampler length is not evenly divisible by the number of
                devices, we must duplicate part of the data (``pad_data_to_num_gpus=True``), or discard part of the
                data (``pad_data_to_num_gpus=False``).

        Note: Wrapper can be used with both the default `SequentialSampler` or `RandomSampler` from `PyTorch` and
        with some custom sampler.

        """
        # The sampler is presented to `DistributedSampler` as a dataset whose items are the sampler's outputs.
        super().__init__(
            dataset=_Sampler2Dataset(sampler), shuffle=shuffle_samples_between_gpus, drop_last=not pad_data_to_num_gpus
        )

        # Counts how many times iteration has been started; it is also the value passed to `set_epoch`.
        self.seed_shift_per_epoch = 0
        self.sampler = sampler

    def _reload(self) -> None:
        """
        We need to re-instantiate the wrapper in order to update the available indices for the new epoch.
        We don't perform this step on the epoch 0, because we want to be comparable with no DDP setup there.

        """
        if self.seed_shift_per_epoch > 0:
            # A fresh `_Sampler2Dataset` forces the wrapped sampler to be read again for this epoch.
            super().__init__(dataset=_Sampler2Dataset(self.sampler), shuffle=self.shuffle, drop_last=self.drop_last)
            # Re-initialisation happens first, then the epoch is set on the re-initialised sampler.
            self.set_epoch(self.seed_shift_per_epoch)

        self.seed_shift_per_epoch += 1

    def __iter__(self) -> Iterator[Union[int, List[int]]]:
        """
        Yield the items of the wrapped sampler located at the positions produced by `DistributedSampler`.

        This is a generator: the reload step runs when the first item is requested, not when `iter()` is called.

        """
        self._reload()

        for sampler_idx in super().__iter__():
            yield self.dataset[sampler_idx]
def extract_loader_parameters(loader: DataLoader, ignore_data_related_parameters: bool = False) -> Dict[str, Any]:
    """
    Collect the constructor arguments of a dataloader (`collate_fn`, `num_workers`, etc.) from its attributes.

    The list of candidate names is taken from the signature of ``DataLoader.__init__``, so parameters added in
    newer versions (e.g. `prefetch_factor`) are picked up automatically. A name is only returned if the loader
    has an attribute with the same name.

    Args:
        loader: loader from which parameters are extracted
        ignore_data_related_parameters: The flag allows you to ignore parameters, related to data, batch content,
            and samplers.

    Returns:
        Mapping from parameter name to the value stored on ``loader``, in the signature's order.

    """
    skipped = {"self"}
    if ignore_data_related_parameters:
        skipped |= {"sampler", "batch_sampler", "drop_last", "shuffle", "batch_size", "dataset"}

    init_parameters = inspect.signature(DataLoader.__init__).parameters

    extracted = {
        name: getattr(loader, name)
        for name in init_parameters
        if name not in skipped and hasattr(loader, name)
    }

    # Sanity check: at least one parameter is expected to be found.
    assert len(extracted)

    return extracted
def patch_dataloader_to_ddp(loader: DataLoader) -> DataLoader:
    """
    Function inspects loader and modifies sampler for working in DDP mode.

    Outside of DDP (``is_ddp()`` is false) a `WarningDDP` warning is issued and the loader is returned unchanged.

    Args:
        loader: loader to be patched

    Returns:
        A new `DataLoader` over the same dataset whose sampler (or batch sampler) is wrapped into
        `DDPSamplerWrapper`, or the original loader if we are not in DDP mode.

    Note:
        We ALWAYS use the padding of samples (in terms of the number of batches or number of samples per epoch) in
        order to use the same amount of data for each device in DDP. Thus, the behavior with and without DDP may be
        slightly different (e.g. metrics values).

    """
    if not is_ddp():
        warnings.warn(patch_dataloader_to_ddp.__name__, WarningDDP)
        return loader

    kwargs_loader = extract_loader_parameters(loader, ignore_data_related_parameters=True)

    batch_sampler = loader.batch_sampler

    # If you don't specify batch_sampler, PyTorch automatically creates default BatchSampler. In this case we
    # need convert to DDP only sampler (your custom sampler / default SequentialSampler or RandomSampler, which
    # PyTorch creates if sampler=None). We don't use `isinstance(...)` for `if` statement because we need exactly
    # class BatchSampler, ignoring any inheritance.
    # `batch_sampler` is None when the loader was built with `batch_size=None` (automatic batching disabled);
    # there is nothing to wrap on the batch level then, so the plain sampler is wrapped as well.
    if batch_sampler is None or type(batch_sampler) is BatchSampler:
        ddp_sampler = DDPSamplerWrapper(
            sampler=loader.sampler, shuffle_samples_between_gpus=False, pad_data_to_num_gpus=True
        )
        patched_loader = DataLoader(
            dataset=loader.dataset,
            sampler=ddp_sampler,
            batch_size=loader.batch_size,
            drop_last=loader.drop_last,
            **kwargs_loader,
        )
        sampler_info = f"'{loader.sampler.__class__.__name__}' sampler"
    else:
        # A custom batch sampler: whole batches (lists of indices) are distributed among devices.
        ddp_sampler = DDPSamplerWrapper(
            sampler=batch_sampler, shuffle_samples_between_gpus=False, pad_data_to_num_gpus=True
        )
        patched_loader = DataLoader(dataset=loader.dataset, batch_sampler=ddp_sampler, **kwargs_loader)
        sampler_info = f"'{batch_sampler.__class__.__name__}' batch sampler"

    logging.info(f"DataLoader with {sampler_info} is updated to DDP mode")
    return patched_loader
def check_loaders_is_patched(loaders: Union[DataLoader, Sequence[DataLoader]]) -> bool:
    """
    Tell whether every given loader uses `DDPSamplerWrapper` as its sampler or batch sampler.

    Args:
        loaders: a single loader or a sequence of loaders

    Returns:
        ``True`` if each loader has a `DDPSamplerWrapper` in `batch_sampler` or `sampler` (also for an empty
        sequence), ``False`` otherwise.

    """
    candidates = (loaders,) if isinstance(loaders, DataLoader) else loaders

    return all(
        any(isinstance(sampler, DDPSamplerWrapper) for sampler in (loader.batch_sampler, loader.sampler))
        for loader in candidates
    )
# Public API of the module; `_Sampler2Dataset` and `TAllSamplers` are not listed here.
__all__ = [
"DDPSamplerWrapper",
"patch_dataloader_to_ddp",
"extract_loader_parameters",
"check_loaders_is_patched",
]