tasks.base_dataset

class pythia.tasks.base_dataset.BaseDataset(name, dataset_type, config={})[source]

Base class for implementing a dataset. Inherits from PyTorch’s Dataset class but adds some custom functionality on top. Instead of __getitem__ you have to implement get_item here. Processors mentioned in the configuration are automatically initialized for the end user.

Parameters:
  • name (str) – Name of your dataset, used to represent it in text strings
  • dataset_type (str) – Type of your dataset. Normally one of train, val, or test
  • config (ConfigNode) – Configuration for the current dataset
get_item(idx)[source]

The equivalent of __getitem__ in a torch dataset. Implement this in your child class instead of __getitem__.

Parameters: idx (int) – Index of the sample to be loaded.
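
The get_item contract can be sketched without installing Pythia. In the sketch below, StandInBaseDataset is a hypothetical stand-in that only mimics the documented constructor signature and assumes that __getitem__ forwards to get_item; it is not Pythia's actual implementation. SquaresDataset and its data are likewise made up for illustration.

```python
class StandInBaseDataset:
    """Hypothetical stand-in for BaseDataset (assumption, not Pythia's code)."""

    def __init__(self, name, dataset_type, config=None):
        self._name = name
        self._dataset_type = dataset_type
        self.config = config if config is not None else {}

    def get_item(self, idx):
        # Child classes are expected to implement this method.
        raise NotImplementedError

    def __getitem__(self, idx):
        # Assumed behavior: indexing forwards to get_item.
        return self.get_item(idx)


class SquaresDataset(StandInBaseDataset):
    """Toy dataset whose sample at index i holds the value i * i."""

    def __init__(self, dataset_type, config=None):
        super().__init__("squares", dataset_type, config)
        self._values = [i * i for i in range(5)]

    def __len__(self):
        return len(self._values)

    def get_item(self, idx):
        return {"idx": idx, "value": self._values[idx]}


dataset = SquaresDataset("train")
sample = dataset[3]  # indexing goes through get_item
```

With the real class, you would subclass pythia.tasks.base_dataset.BaseDataset in the same way and return whatever sample object your model expects.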
load_item(idx)[source]

Implement if you need to separately load the item and cache it.

Parameters: idx (int) – Index of the sample to be loaded.
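
One possible way to use load_item for caching is sketched below. The documentation does not specify how get_item and load_item interact, so the pattern here (get_item calling load_item, with an in-memory dict as the cache) is an assumption. The class CachedItemsDataset, the helper _read_from_storage, and the load_count counter are hypothetical names used only for illustration.

```python
class CachedItemsDataset:
    """Hypothetical dataset illustrating a separate, cached load step."""

    def __init__(self):
        self._cache = {}
        self.load_count = 0  # counts real loads, for demonstration only

    def _read_from_storage(self, idx):
        # Placeholder for an expensive read (disk, network, decoding).
        self.load_count += 1
        return {"idx": idx, "text": "sample-%d" % idx}

    def load_item(self, idx):
        # Load the item once and keep it in the cache afterwards.
        if idx not in self._cache:
            self._cache[idx] = self._read_from_storage(idx)
        return self._cache[idx]

    def get_item(self, idx):
        # Assumed wiring: get_item delegates the loading to load_item.
        return self.load_item(idx)


cached = CachedItemsDataset()
first = cached.get_item(2)
second = cached.get_item(2)  # served from the cache, no second load
```

An unbounded dict is the simplest choice for a sketch; a large dataset would likely need an eviction policy or an on-disk cache instead.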
prepare_batch(batch)[source]

Can be overridden in your child class.

Prepares the batch for passing to the model. Whatever is returned from here will be passed directly to the model’s forward function. Currently moves the batch to the proper device.

Parameters: batch (SampleList) – Sample list containing the currently loaded batch
Returns: A sample list representing the currently loaded batch
Return type: sample_list (SampleList)
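
A minimal sketch of overriding prepare_batch follows. It assumes the base method moves the batch by calling batch.to(device), which matches the description above but is not confirmed here. FakeSampleList, StandInBaseDataset, CountingDataset, and the num_samples field are hypothetical stand-ins so that the example runs without Pythia or PyTorch.

```python
class FakeSampleList(dict):
    """Hypothetical stand-in for SampleList with a to(device) method."""

    def to(self, device):
        # Return a copy that records the device it was "moved" to.
        moved = FakeSampleList(self)
        moved["device"] = device
        return moved


class StandInBaseDataset:
    """Hypothetical stand-in for BaseDataset (assumption, not Pythia's code)."""

    def __init__(self, device="cpu"):
        self._device = device

    def prepare_batch(self, batch):
        # Assumed default behavior: move the batch to the target device.
        return batch.to(self._device)


class CountingDataset(StandInBaseDataset):
    def prepare_batch(self, batch):
        # Keep the default device handling, then add an extra field.
        batch = super().prepare_batch(batch)
        batch["num_samples"] = len(batch["text"])
        return batch


prepared = CountingDataset().prepare_batch(FakeSampleList(text=["a", "b"]))
```

Calling the parent method first keeps the device handling in one place, so the override only adds what is specific to the child class.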