# Copyright (c) Facebook, Inc. and its affiliates.
"""
``Sample`` and ``SampleList`` are data structures for arbitrary data returned from a
dataset. To work with Pythia, the minimum requirement for datasets is to return
an object of the ``Sample`` class and for models to accept an object of type
``SampleList`` as an argument.
``Sample`` is used to represent an arbitrary sample from a dataset, while ``SampleList``
is a list of ``Sample`` combined in an efficient way to be used by the model.
In simple terms, ``SampleList`` is a batch of ``Sample`` that allows easy access to
attributes from ``Sample`` while taking care of properly batching things.
"""
import collections
from collections import OrderedDict
from copy import deepcopy
import torch
class Sample(OrderedDict):
    """Sample represents some arbitrary data. All datasets in Pythia must
    return an object of type ``Sample``.

    Every key stored in the sample is also reachable as an attribute, and
    every attribute assignment is stored as a key.

    Args:
        init_dict (Dict): Dictionary to init ``Sample`` class with. ``None``
            (the default) creates an empty sample.

    Usage::

        >>> sample = Sample({"text": torch.tensor(2)})
        >>> sample.text.zero_()
        # Custom attributes can be added to ``Sample`` after initialization
        >>> sample.context = torch.tensor(4)
    """

    def __init__(self, init_dict=None):
        # ``None`` replaces a mutable ``{}`` default so that no dict object is
        # shared between calls.
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def __setattr__(self, key, value):
        # Attribute assignment is redirected into the mapping itself.
        self[key] = value

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails; fall back to the
        # mapping and translate a missing key into ``AttributeError`` so that
        # ``hasattr``/``getattr`` with a default behave as expected.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def fields(self):
        """Get current attributes/fields registered under the sample.

        Returns:
            List[str]: Attributes registered under the Sample.
        """
        return list(self.keys())
class SampleList(OrderedDict):
    """``SampleList`` is used to collate a list of ``Sample`` into a batch during batch
    preparation. It can be thought of as a merger of list of Dicts into a single Dict.

    If ``Sample`` contains an attribute 'text' of size (2) and there are 10 samples in
    list, the returned ``SampleList`` will have an attribute 'text' which is a tensor
    of size (10, 2).

    Args:
        samples (type): List of ``Sample`` from which the ``SampleList`` will be
            created. A mapping, or a list of ``(key, value)`` pairs whose keys
            are strings, is also accepted; in that case each pair is added as
            a field as-is. ``None`` or an empty container creates an empty
            ``SampleList``.

    Usage::

        >>> sample_list = SampleList(
        ...     [Sample({"text": torch.tensor(2)}), Sample({"text": torch.tensor(2)})]
        ... )
        >>> sample_list.text
        tensor([2, 2])
    """

    # Key in the instance ``__dict__`` under which the name of the first
    # tensor field is remembered; that field is used to derive the batch size.
    _TENSOR_FIELD_ = "_tensor_field"

    def __init__(self, samples=None):
        super().__init__()

        # ``None`` replaces a mutable ``[]`` default.
        if samples is None or len(samples) == 0:
            return

        if self._check_and_load_dict(samples):
            return

        # If passed sample list was in form of key, value pairs of tuples
        # return after loading these
        if self._check_and_load_tuple(samples):
            return

        first = samples[0]
        batch_size = len(samples)

        # The first sample decides which fields exist and how each is stored.
        for field in first.keys():
            reference = first[field]

            if isinstance(reference, torch.Tensor):
                # Pre-allocate a batched tensor of shape (batch, *sample_shape).
                self[field] = reference.new_empty((batch_size, *reference.size()))
                if self._get_tensor_field() is None:
                    self._set_tensor_field(field)
            else:
                self[field] = [None] * batch_size

            for idx, sample in enumerate(samples):
                value = sample[field]
                # it should be a tensor but not a 0-d tensor
                if (
                    isinstance(value, torch.Tensor)
                    and len(value.size()) != 0
                    and value.size(0) != reference.size(0)
                ):
                    raise AssertionError(
                        "Fields for all samples must be equally sized. "
                        "{} is of different sizes".format(field)
                    )

                self[field][idx] = self._get_data_copy(value)

            # A list of mappings is itself collated into a nested SampleList.
            if isinstance(reference, collections.abc.Mapping):
                self[field] = SampleList(self[field])

    def _check_and_load_tuple(self, samples):
        """Load ``samples`` as ``(key, value)`` pairs if it looks like such a list.

        Returns:
            bool: True if the first element is a tuple/list starting with a
            string and all pairs were added as fields, False otherwise.
        """
        if isinstance(samples[0], (tuple, list)) and isinstance(samples[0][0], str):
            for kv_pair in samples:
                self.add_field(kv_pair[0], kv_pair[1])
            return True
        return False

    def _check_and_load_dict(self, samples):
        """Load ``samples`` field by field if it is a mapping.

        Returns:
            bool: True if ``samples`` was a mapping and was loaded, False otherwise.
        """
        if isinstance(samples, collections.abc.Mapping):
            for key, value in samples.items():
                self.add_field(key, value)
            return True
        return False

    def _fix_sample_type(self, samples):
        """Return ``samples`` with every element wrapped in ``Sample``.

        The conversion only happens when the first element is not already a
        ``Sample``; otherwise ``samples`` is returned unchanged.
        """
        if not isinstance(samples[0], Sample):
            samples = [Sample(sample) for sample in samples]
        return samples

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails; expose stored
        # fields as attributes and report the valid fields otherwise.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(
                "Key {} not found in the SampleList. "
                "Valid choices are {}".format(key, self.fields())
            ) from None

    def get_item_list(self, key):
        """Get ``SampleList`` of only one particular attribute that is present
        in the ``SampleList``.

        Args:
            key (str): Attribute whose ``SampleList`` will be made.

        Returns:
            SampleList: SampleList built from a one-element list holding the
            value stored under ``key``.
        """
        # NOTE(review): the stored value is passed through the regular
        # constructor as a single "sample", so this presumably expects the
        # value to be a mapping (e.g. a nested SampleList) - confirm with
        # callers before relying on it for tensor or list fields.
        sample = self[key]
        return SampleList([sample])

    def copy(self):
        """Get a copy of the current SampleList

        Every field is re-added through ``add_field``, so the values in the
        returned object are copies of the current ones.

        Returns:
            SampleList: Copy of current SampleList.
        """
        sample_list = SampleList()
        for field in self.fields():
            sample_list.add_field(field, self[field])
        return sample_list

    def fields(self):
        """Get current attributes/fields registered under the SampleList.

        Returns:
            List[str]: list of attributes of the SampleList.
        """
        return list(self.keys())

    def get_fields(self, fields):
        """Get a new ``SampleList`` generated from the current ``SampleList``
        but contains only the attributes passed in `fields` argument

        Args:
            fields (List[str]): Attributes whose ``SampleList`` will be made.

        Returns:
            SampleList: SampleList containing only the attribute values of the fields
            which were passed.

        Raises:
            AttributeError: If any requested field is not present.
        """
        current_fields = self.fields()
        return_list = SampleList()

        for field in fields:
            if field not in current_fields:
                raise AttributeError(
                    "{} not present in SampleList. "
                    "Valid choices are {}".format(field, current_fields)
                )
            return_list.add_field(field, self[field])

        return return_list

    def get_field(self, field):
        """Get value of a particular attribute

        Args:
            field (str): Attribute whose value is to be returned.
        """
        return self[field]

    def _get_data_copy(self, data):
        """Return a copy of ``data``: ``clone()`` for tensors, ``deepcopy`` otherwise."""
        if isinstance(data, torch.Tensor):
            return data.clone()
        return deepcopy(data)

    def _get_tensor_field(self):
        """Return the name of the remembered tensor field, or None if unset."""
        return self.__dict__.get(SampleList._TENSOR_FIELD_, None)

    def _set_tensor_field(self, value):
        """Remember ``value`` as the name of the tensor field used for batch size."""
        self.__dict__[SampleList._TENSOR_FIELD_] = value

    def get_batch_size(self):
        """Get batch size of the current ``SampleList``. There must be a tensor
        field present in the ``SampleList`` currently.

        Returns:
            int: Size of the batch in ``SampleList``.

        Raises:
            AssertionError: If no tensor field has been registered yet.
        """
        tensor_field = self._get_tensor_field()
        # Raised explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if tensor_field is None:
            raise AssertionError("There is no tensor yet in SampleList")
        return self[tensor_field].size(0)

    def add_field(self, field, data):
        """Add an attribute ``field`` with value ``data`` to the SampleList

        A copy of ``data`` is stored. If this is the first tensor added, it
        becomes the field that defines the batch size.

        Args:
            field (str): Key under which the data will be added.
            data (object): Data to be added, can be a ``torch.Tensor``, ``list``
                or ``Sample``

        Raises:
            AssertionError: If ``data`` is a non 0-d tensor whose first
                dimension differs from that of the registered tensor field.
        """
        tensor_field = self._get_tensor_field()

        if (
            len(self) != 0
            and isinstance(data, torch.Tensor)
            and len(data.size()) != 0
            and tensor_field is not None
            and data.size(0) != self[tensor_field].size(0)
        ):
            # Report the sizes that were actually compared; the first field
            # of the SampleList is not necessarily the tensor field.
            raise AssertionError(
                "A tensor field to be added must "
                "have same size as existing tensor "
                "fields in SampleList. "
                "Passed size: {}, Required size: {}".format(
                    data.size(0), self[tensor_field].size(0)
                )
            )

        self[field] = self._get_data_copy(data)

        if isinstance(self[field], torch.Tensor) and tensor_field is None:
            self._set_tensor_field(field)

    def to(self, device, non_blocking=True):
        """Similar to ``.to`` function on a `torch.Tensor`. Moves all of the
        tensors present inside the ``SampleList`` to a particular device. If an
        attribute's value is not a tensor, it is ignored and kept as it is.

        Args:
            device (str|torch.device): Device on which the ``SampleList`` should
                moved.
            non_blocking (bool): Whether the move should be non_blocking. Default: True

        Returns:
            SampleList: a SampleList moved to the ``device``.

        Raises:
            TypeError: If ``device`` is neither a ``str`` nor a ``torch.device``.
        """
        # Validate the device before copying so a bad argument fails fast.
        if not isinstance(device, torch.device):
            if not isinstance(device, str):
                raise TypeError(
                    "device must be either 'str' or "
                    "'torch.device' type, {} found".format(type(device))
                )
            device = torch.device(device)

        sample_list = self.copy()

        for field in sample_list.fields():
            value = sample_list[field]
            # Anything exposing ``to`` (tensors, nested SampleLists) is moved.
            if hasattr(value, "to"):
                sample_list[field] = value.to(device, non_blocking=non_blocking)

        return sample_list