common.sample

Sample and SampleList are data structures for arbitrary data returned from a dataset. To work with Pythia, the minimum requirement is that datasets return an object of the Sample class and that models accept an object of type SampleList as an argument.

Sample represents an arbitrary sample from a dataset. SampleList combines a list of Sample objects efficiently so the model can use them. In simple terms, a SampleList is a batch of Sample objects. It gives easy access to the attributes of each Sample and takes care of batching them properly.

class pythia.common.sample.Sample(init_dict={})

Sample represents arbitrary data. All datasets in Pythia must return an object of type Sample.

Parameters: init_dict (Dict) – Dictionary to initialize the Sample with.

Usage:

>>> import torch
>>> from pythia.common.sample import Sample
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
tensor(0)
>>> # Custom attributes can be added to ``Sample`` after initialization
>>> sample.context = torch.tensor(4)
fields()

Get the attributes/fields currently registered under the Sample.

Returns: Attributes registered under the Sample.
Return type: List[str]
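The attribute-style access in the usage example can be sketched with a plain OrderedDict subclass. This is a hypothetical, dependency-free illustration (the class name `AttrSample` is made up, and plain lists stand in for tensors). It assumes that attributes simply map onto dictionary keys and is not Pythia's actual source.

```python
from collections import OrderedDict


class AttrSample(OrderedDict):
    """Hypothetical stand-in for Sample: a dict whose keys double as attributes."""

    def __init__(self, init_dict=None):
        # Use None instead of a mutable default so instances never share a dict.
        super().__init__()
        for key, value in (init_dict or {}).items():
            self[key] = value

    def __getattr__(self, key):
        # Only called when normal lookup fails, so dict methods keep working.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Route attribute assignment into the underlying dict.
        self[key] = value

    def fields(self):
        # Registered attribute names, in insertion order.
        return list(self.keys())


sample = AttrSample({"text": [2]})
sample.context = [4]
print(sample.fields())  # ['text', 'context']
print(sample.text)      # [2]
```

Raising AttributeError (rather than KeyError) for unknown names keeps `hasattr` and `getattr(obj, name, default)` behaving as Python expects.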
class pythia.common.sample.SampleList(samples=[])

SampleList is used to collate a list of Sample objects into a batch during batch preparation. It can be thought of as merging a list of dicts into a single dict.

If each Sample contains an attribute ‘text’ of size (2) and there are 10 samples in the list, the returned SampleList will have an attribute ‘text’ which is a tensor of size (10, 2).

Parameters: samples (List[Sample]) – List of Sample objects from which the SampleList will be created.

Usage:

>>> import torch
>>> from pythia.common.sample import Sample, SampleList
>>> sample_list = SampleList([Sample({"text": torch.tensor(2)}), Sample({"text": torch.tensor(2)})])
>>> sample_list.text
tensor([2, 2])
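The idea of merging a list of dicts into a single dict can be sketched without PyTorch by using nested Python lists in place of tensors. The names `collate_samples` and `shape` are hypothetical, and this is only an illustration of the idea. Each field's values are gathered in sample order, so ten samples with a length-2 `text` field yield a nested list of shape (10, 2). A real implementation would produce a single batched torch.Tensor instead.

```python
def collate_samples(samples):
    """Merge a list of dicts with identical keys into one dict of lists."""
    if not samples:
        return {}
    batch = {}
    for field in samples[0].keys():
        # Gather this field from every sample, preserving sample order.
        batch[field] = [sample[field] for sample in samples]
    return batch


def shape(nested):
    """Return the shape of a rectangular nested list, similar to Tensor.size()."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


samples = [{"text": [i, i + 1], "id": i} for i in range(10)]
batch = collate_samples(samples)
print(shape(batch["text"]))  # (10, 2)
print(batch["id"][:3])       # [0, 1, 2]
```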
add_field(field, data)

Add an attribute field with value data to the SampleList.

Parameters:
  • field (str) – Key under which the data will be added.
  • data (object) – Data to be added; can be a torch.Tensor, list or Sample.
copy()

Get a copy of the current SampleList.

Returns: Copy of the current SampleList.
Return type: SampleList
fields()

Get current attributes/fields registered under the SampleList.

Returns: List of attributes of the SampleList.
Return type: List[str]
get_batch_size()

Get the batch size of the current SampleList. At least one tensor field must currently be present in the SampleList.

Returns: Size of the batch in the SampleList.
Return type: int
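As a rough sketch of the rule above, the batch size can be read from the leading dimension of a batched field. The standalone function below is hypothetical. It uses nested lists in place of tensors, takes the first list-valued field in insertion order, and raises an error when none exists. That error mirrors the requirement that a tensor field be present.

```python
def get_batch_size(batch):
    """Return the leading dimension of the first list-valued field in `batch`."""
    for value in batch.values():
        # In this sketch, nested lists stand in for tensors; other values are skipped.
        if isinstance(value, list):
            return len(value)
    raise ValueError("no tensor-like field present to infer the batch size from")


batch = {"name": "demo", "text": [[1, 2], [3, 4], [5, 6]]}
print(get_batch_size(batch))  # 3
```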
get_field(field)

Get the value of a particular attribute.

Parameters: field (str) – Attribute whose value is to be returned.
get_fields(fields)

Get a new SampleList generated from the current SampleList that contains only the attributes passed in the fields argument.

Parameters: fields (List[str]) – Attributes to include in the new SampleList.
Returns: SampleList containing only the values of the attributes that were passed.
Return type: SampleList
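Selecting a subset of fields amounts to building a new mapping with only the requested keys. The sketch below uses a plain dict, and the function name `get_fields` here is a hypothetical standalone helper, not the method itself. Raising KeyError for unknown names is a choice of this sketch, not documented Pythia behavior. The values are shared with the original batch rather than copied.

```python
def get_fields(batch, fields):
    """Return a new dict holding only the requested fields of `batch`."""
    missing = [field for field in fields if field not in batch]
    if missing:
        # This sketch fails loudly instead of silently dropping unknown names.
        raise KeyError(f"fields not present in batch: {missing}")
    # Values are shared, not copied: the new dict references the same objects.
    return {field: batch[field] for field in fields}


batch = {"text": [[1, 2], [3, 4]], "image": [[0.1], [0.2]], "id": [7, 8]}
subset = get_fields(batch, ["text", "id"])
print(list(subset))  # ['text', 'id']
```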
get_item_list(key)

Get a SampleList of only one particular attribute present in the current SampleList.

Parameters: key (str) – Attribute whose SampleList will be made.
Returns: SampleList containing only the value of the attribute that was passed.
Return type: SampleList
to(device, non_blocking=True)

Similar to the .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute’s value is not a tensor, it is ignored and kept as it is.

Parameters:
  • device (str|torch.device) – Device to which the SampleList should be moved.
  • non_blocking (bool) – Whether the move should be non-blocking. Default: True
Returns: A SampleList moved to the device.
Return type: SampleList
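The traversal described above can be sketched with duck typing: any value that has a `.to` method is treated as a tensor, and everything else is passed through unchanged. `FakeTensor` and `move_to` are hypothetical names, chosen so the example runs without PyTorch or a GPU. With real tensors, `torch.Tensor.to(device, non_blocking=...)` performs the actual move. This sketch returns a new dict and leaves the input untouched.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device, non_blocking=False):
        # Like Tensor.to, return a new object instead of modifying this one.
        return FakeTensor(device)


def move_to(batch, device, non_blocking=True):
    """Return a copy of `batch` with every `.to`-capable value moved to `device`."""
    moved = {}
    for field, value in batch.items():
        if hasattr(value, "to"):
            moved[field] = value.to(device, non_blocking=non_blocking)
        else:
            # Non-tensor values (strings, lists, ...) are kept as they are.
            moved[field] = value
    return moved


batch = {"text": FakeTensor(), "question_id": ["q1", "q2"]}
gpu_batch = move_to(batch, "cuda:0")
print(gpu_batch["text"].device)  # cuda:0
print(gpu_batch["question_id"])  # ['q1', 'q2']
```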