# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import typing as t

import numpy as np
import torch

from smartredis import Client, Dataset

from smartsim.ml.data import DataDownloader


class _TorchDataGenerationCommon(DataDownloader, torch.utils.data.IterableDataset):
    def __init__(self, **kwargs: t.Any) -> None:
        # Sample initialization is deferred: `_worker_init_fn` calls
        # `init_samples()` in each DataLoader worker process.
        init_samples = kwargs.pop("init_samples", False)
        kwargs["init_samples"] = False
        super().__init__(**kwargs)
        if init_samples:
            self.log(
                "PyTorch Data Generator has to be created with "
                "init_samples=False. Setting it to False automatically."
            )

    def _add_samples(self, indices: t.List[int]) -> None:
        if self.client is None:
            client = Client(self.cluster, self.address)
        else:
            client = self.client

        datasets: t.List[Dataset] = []
        if self.num_replicas == 1:
            # A single replica owns the whole list: fetch the range in one call
            datasets = client.get_dataset_list_range(
                self.list_name, start_index=indices[0], end_index=indices[-1]
            )
        else:
            # With multiple replicas the indices are not contiguous:
            # fetch them one at a time
            for idx in indices:
                datasets += client.get_dataset_list_range(
                    self.list_name, start_index=idx, end_index=idx
                )

        if self.samples is None:
            # First call: initialize the tensors from the first dataset,
            # then drop it so it is not concatenated a second time below
            self.samples = torch.tensor(datasets[0].get_tensor(self.sample_name))
            if self.need_targets:
                self.targets = torch.tensor(datasets[0].get_tensor(self.target_name))
            datasets = datasets[1:]

        for dataset in datasets:
            self.samples = torch.cat(
                (self.samples, torch.tensor(dataset.get_tensor(self.sample_name)))
            )
            if self.need_targets:
                self.targets = torch.cat(
                    (self.targets, torch.tensor(dataset.get_tensor(self.target_name)))
                )

        if self.samples is not None:
            self.num_samples = self.samples.shape[0]
            self.indices = np.arange(self.num_samples)
            self.log(f"New dataset size: {self.num_samples}, batches: {len(self)}")


class StaticDataGenerator(_TorchDataGenerationCommon):
    """A class to download a dataset from the DB.

    Details about parameters and features of this class can be found
    in the documentation of ``DataDownloader``, of which it is just
    a PyTorch-specialized sub-class with ``dynamic=False`` and
    ``init_samples=False``.

    When used with the ``DataLoader`` defined in this module, samples
    are initialized automatically before training. Other data loaders
    using this generator should implement the same behavior.
    """

    def __init__(self, **kwargs: t.Any) -> None:
        dynamic = kwargs.pop("dynamic", False)
        kwargs["dynamic"] = False
        super().__init__(**kwargs)
        if dynamic:
            self.log(
                "Static data generator cannot be started "
                "with dynamic=True, setting it to False"
            )
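
# A minimal usage sketch for ``StaticDataGenerator`` (illustrative only).
# The keyword arguments are forwarded to ``DataDownloader``; the names used
# here (``data_info_or_list_name``, ``cluster``) and the list name
# "training_data" are assumptions for the example, not values prescribed
# by this module:
#
#     from smartsim.ml.torch import StaticDataGenerator, DataLoader
#
#     generator = StaticDataGenerator(
#         data_info_or_list_name="training_data",  # aggregation list on the DB
#         cluster=False,                           # single-shard database
#     )
#     loader = DataLoader(generator, batch_size=None, num_workers=2)
#     for samples, targets in loader:  # assumes targets were uploaded too
#         ...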


class DynamicDataGenerator(_TorchDataGenerationCommon):
    """A class to download batches from the DB.

    Details about parameters and features of this class can be found
    in the documentation of ``DataDownloader``, of which it is just
    a PyTorch-specialized sub-class with ``dynamic=True`` and
    ``init_samples=False``.

    When used with the ``DataLoader`` defined in this module, samples
    are initialized automatically before training. Other data loaders
    using this generator should implement the same behavior.
    """

    def __init__(self, **kwargs: t.Any) -> None:
        dynamic = kwargs.pop("dynamic", True)
        kwargs["dynamic"] = True
        super().__init__(**kwargs)
        if not dynamic:
            self.log(
                "Dynamic data generator cannot be started with dynamic=False, "
                "setting it to True"
            )


def _worker_init_fn(worker_id: int) -> None:
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    worker_id = worker_info.id
    num_workers = worker_info.num_workers

    # Re-rank this worker so that, across all replicas, every DataLoader
    # worker process fetches a disjoint part of the aggregated list
    dataset.set_replica_parameters(
        replica_rank=dataset.replica_rank * num_workers + worker_id,
        num_replicas=dataset.num_replicas * num_workers,
    )
    dataset.log(
        f"Worker {worker_id+1}/{num_workers}: dataset replica "
        f"{dataset.replica_rank+1}/{dataset.num_replicas}"
    )

    dataset.init_samples()
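
# Worked example of the re-ranking above (hypothetical numbers): with two
# data-parallel replicas (num_replicas=2), each wrapped in a DataLoader with
# two workers (num_workers=2), replica 1's worker 0 becomes replica
# 1 * 2 + 0 = 2 of 2 * 2 = 4, so the four worker processes end up fetching
# disjoint quarters of the aggregated dataset list.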


class DataLoader(torch.utils.data.DataLoader):  # pragma: no cover
    """DataLoader to be used as a wrapper of ``StaticDataGenerator``
    or ``DynamicDataGenerator``.

    This is just a sub-class of ``torch.utils.data.DataLoader`` which
    sets up the sources of a data generator correctly. DataLoader
    parameters such as ``num_workers`` can be passed at initialization;
    ``batch_size`` should always be set to ``None``, as the generator
    already yields whole batches.
    """

    def __init__(self, dataset: _TorchDataGenerationCommon, **kwargs: t.Any) -> None:
        super().__init__(
            dataset,
            worker_init_fn=_worker_init_fn,
            persistent_workers=True,
            **kwargs,
        )
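
# Design note: ``persistent_workers=True`` keeps the worker processes, and the
# dataset copies initialized by ``_worker_init_fn``, alive across epochs. With
# ``DynamicDataGenerator`` this lets each worker keep downloading new batches
# appended to the dataset list during training. A sketch of the dynamic case,
# under the same assumed argument names as the static example above:
#
#     generator = DynamicDataGenerator(
#         data_info_or_list_name="training_data",
#         cluster=False,
#     )
#     loader = DataLoader(generator, batch_size=None, num_workers=2)
#     for epoch in range(10):
#         for samples, targets in loader:
#             ...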