tasks.base_dataset
class pythia.tasks.base_dataset.BaseDataset(name, dataset_type, config={})
Base class for implementing a dataset. Inherits from PyTorch’s Dataset class but adds some custom functionality on top. Instead of __getitem__, you have to implement get_item here. Processors mentioned in the configuration are automatically initialized for the end user.
Parameters:
- name (str) – Name of your dataset, used to represent it in text strings
- dataset_type (str) – Type of your dataset. Normally, train|val|test
- config (ConfigNode) – Configuration for the current dataset
get_item(idx)
Plays the role of __getitem__ in a torch dataset. Implement this in your child class.
Parameters: idx (int) – Index of the sample to be loaded.
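The get_item contract can be illustrated with a minimal sketch. StandInBaseDataset below is a hypothetical stand-in written so the example runs without Pythia or PyTorch installed; it assumes that indexing delegates to get_item, which may not match the real class in every detail. ToyCaptionDataset and its captions are likewise invented for illustration.

```python
# Stand-in for pythia.tasks.base_dataset.BaseDataset (assumption, not the
# real implementation): it only models the name/dataset_type/config
# arguments and the delegation from __getitem__ to get_item.
class StandInBaseDataset:
    def __init__(self, name, dataset_type, config=None):
        self._name = name
        self._dataset_type = dataset_type
        self.config = config if config is not None else {}

    def get_item(self, idx):
        raise NotImplementedError("Child classes implement get_item")

    def __getitem__(self, idx):
        # Assumed behaviour: indexing forwards to get_item
        return self.get_item(idx)


class ToyCaptionDataset(StandInBaseDataset):
    """Hypothetical child dataset holding a few in-memory captions."""

    def __init__(self, dataset_type, config=None):
        super().__init__("toy_captions", dataset_type, config)
        self._captions = ["a cat", "a dog", "a bird"]

    def __len__(self):
        return len(self._captions)

    def get_item(self, idx):
        # Build and return the sample for this index
        return {"text": self._captions[idx], "dataset_type": self._dataset_type}


dataset = ToyCaptionDataset("train")
sample = dataset[1]  # goes through __getitem__, which calls get_item
```

The child class never defines __getitem__ itself; it only supplies get_item and lets the base class handle indexing.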
load_item(idx)
Implement this if you need to load the item separately and cache it.
Parameters: idx (int) – Index of the sample to be loaded.
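One possible way to pair load_item with a cache is sketched below. This is an assumption about usage, not Pythia's own caching logic: CachedToyDataset, its dictionary cache, and the load_calls counter are all hypothetical and exist only to show the split between loading and serving an item.

```python
class CachedToyDataset:
    """Hypothetical dataset; stands in for a BaseDataset child class."""

    def __init__(self):
        self._cache = {}
        self.load_calls = 0  # counts how often the expensive path runs

    def load_item(self, idx):
        # Pretend this is expensive (disk read, feature extraction, ...)
        self.load_calls += 1
        return {"idx": idx, "value": idx * idx}

    def get_item(self, idx):
        # Load on first access only, then serve the cached object
        if idx not in self._cache:
            self._cache[idx] = self.load_item(idx)
        return self._cache[idx]


dataset = CachedToyDataset()
first = dataset.get_item(3)
second = dataset.get_item(3)  # served from the cache, no second load
```

Keeping the loading code in load_item means get_item stays a thin wrapper, and the caching strategy can change without touching the loading code.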
prepare_batch(batch)
Can be overridden in your child class.
Prepares the batch for passing to the model. Whatever is returned from here is passed directly to the model’s forward function. Currently, it moves the batch to the proper device.
Parameters: batch (SampleList) – Sample list containing the currently loaded batch
Returns: A sample list representing the currently loaded batch
Return type: sample_list (SampleList)
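An override of prepare_batch would typically call the parent implementation first so the device move is kept, then return the batch. The sketch below uses plain dictionaries instead of SampleList and only tags the batch with a device name instead of moving tensors. Both classes and the num_samples field are hypothetical stand-ins, not part of Pythia.

```python
class StandInBaseDataset:
    """Stand-in base class (assumption): imitates the device move by tagging."""

    def __init__(self, device="cpu"):
        self._device = device

    def prepare_batch(self, batch):
        # The real class moves the batch to the proper device; this stand-in
        # only records the target device on a copy of the batch.
        batch = dict(batch)
        batch["device"] = self._device
        return batch


class ToyDataset(StandInBaseDataset):
    def prepare_batch(self, batch):
        # Keep the parent behaviour, then add a hypothetical extra field
        batch = super().prepare_batch(batch)
        batch["num_samples"] = len(batch["text"])
        return batch  # this is what the model's forward would receive


prepared = ToyDataset(device="cpu").prepare_batch({"text": ["a cat", "a dog"]})
```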