Source code for pythia.tasks.base_task

# Copyright (c) Facebook, Inc. and its affiliates.
"""
Tasks sit above datasets in the hierarchy. To implement a new task,
inherit from the ``BaseTask`` class and implement the
``_get_available_datasets`` and ``_preprocess_item`` methods. Check the
source to see whether you need to override any other methods, such as
``prepare_batch``.

See the ``VQATask`` example here_.

Example::

    from pythia.tasks.base_task import BaseTask
    from pythia.common.registry import registry


    @registry.register_task("my")
    class MyTask(BaseTask):
        def __init__(self):
            super().__init__("my")

        def _get_available_datasets(self):
            return ["my"]

        def _preprocess_item(self, item):
            item.text = None
            return item

.. _here: https://github.com/facebookresearch/pythia/blob/v0.3/pythia/tasks/vqa/vqa_task.py
"""

import sys

import numpy as np
from torch.utils.data import Dataset

from pythia.common.registry import registry
from pythia.utils.distributed_utils import synchronize, is_main_process


class BaseTask(Dataset):
    """
    BaseTask that task classes need to inherit in order to create a new task.

    Users must implement ``_get_available_datasets`` and ``_preprocess_item``
    in order to complete the implementation.

    Args:
        task_name (str): Name of the task with which it will be registered
    """

    def __init__(self, task_name):
        super(BaseTask, self).__init__()
        self.task_name = task_name
        self.writer = registry.get("writer")

    def _process_datasets(self):
        if "datasets" not in self.opts:
            self.writer.write(
                "No datasets attribute present for task: %s."
                " Defaulting to all" % (self.task_name),
                "warning",
            )
            datasets = "all"
        else:
            datasets = self.opts["datasets"]

        if datasets is None or datasets == "all":
            datasets = self._get_available_datasets()

        if type(datasets) == str:
            datasets = list(map(lambda x: x.strip(), datasets.split(",")))

        # Fall back to every available dataset when the parsed list is empty
        # or explicitly asks for "all". (The original condition used ``and``,
        # which could never be true; ``or`` matches the intended behavior.)
        if len(datasets) == 0 or datasets[0] == "all":
            datasets = self._get_available_datasets()

        self.given_datasets = datasets

    def load(self, **opts):
        self.opts = opts
        self._process_datasets()

        self.datasets = []
        self.builders = []
        available_datasets = self._get_available_datasets()

        self.total_length = 0
        self.per_dataset_lengths = []
        self.num_datasets = 0

        for dataset in self.given_datasets:
            if dataset in available_datasets:
                builder_class = registry.get_builder_class(dataset)

                if builder_class is None:
                    print("No builder class found for %s." % dataset)
                    continue
                builder_instance = builder_class()

                if dataset in self.opts["dataset_attributes"]:
                    attributes = self.opts["dataset_attributes"][dataset]
                else:
                    self.writer.write(
                        "Dataset %s is missing from "
                        "dataset_attributes in config." % dataset,
                        "error",
                    )
                    sys.exit(1)

                dataset_type = self.opts.get("dataset_type", "train")

                builder_instance.build(dataset_type, attributes)
                dataset_instance = builder_instance.load(dataset_type, attributes)

                if dataset_instance is None:
                    continue

                self.builders.append(builder_instance)
                self.datasets.append(dataset_instance)
                self.per_dataset_lengths.append(len(dataset_instance))
                self.total_length += len(dataset_instance)
            else:
                print(
                    "Dataset %s is not a valid dataset for task %s. Skipping"
                    % (dataset, self.task_name)
                )

        self.num_datasets = len(self.datasets)
        # Uniform weights by default, normalized up front so that
        # ``np.random.choice`` accepts them in ``change_dataset``. (The
        # original unnormalized [1, 1, ...] would raise a ValueError there
        # whenever more than one dataset is loaded.)
        self.dataset_probabilities = [
            1 / self.num_datasets for _ in range(self.num_datasets)
        ]
        sampling = self.opts.get("dataset_size_proportional_sampling", None)

        if sampling is True:
            self.dataset_probabilities = [
                length / self.total_length for length in self.per_dataset_lengths
            ]

        self.change_dataset()
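    # A minimal sketch of how a trainer might drive ``load`` (the option
    # values below are hypothetical; real ones come from the trainer's
    # configuration):
    #
    #     task = MyTask()
    #     task.load(
    #         dataset_type="train",
    #         datasets="vqa2,textvqa",
    #         dataset_attributes={"vqa2": {...}, "textvqa": {...}},
    #     )
    #
    # ``datasets`` may be a comma-separated string or a list of names; each
    # name must have a registered builder and an entry in
    # ``dataset_attributes``, otherwise it is skipped or the run exits.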
    def _get_available_datasets(self):
        """Set available datasets for this task here. Override in your
        child task class.

        This is a temporary solution; later we will use decorators to easily
        register datasets with a task.

        Returns:
            List - List of available datasets for this particular task
        """
        return []
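    # Example override (the dataset names are purely illustrative):
    #
    #     def _get_available_datasets(self):
    #         return ["vqa2", "textvqa"]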
    def get_datasets(self):
        return self.datasets

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        idx = idx % self.per_dataset_lengths[self.dataset_choice]
        item = self.chosen_dataset[idx]
        return self._preprocess_item(item)

    def change_dataset(self):
        self.dataset_choice = np.random.choice(
            self.num_datasets, 1, p=self.dataset_probabilities
        )[0]
        self.chosen_dataset = self.datasets[self.dataset_choice]

    def verbose_dump(self, *args, **kwargs):
        self.chosen_dataset.verbose_dump(*args, **kwargs)

    def prepare_batch(self, batch):
        return self.chosen_dataset.prepare_batch(batch)
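    # Worked example of the sampling above (hypothetical sizes): with two
    # datasets of 400,000 and 100,000 items and size-proportional sampling
    # enabled, ``total_length`` is 500,000 and ``dataset_probabilities``
    # becomes [0.8, 0.2], so ``change_dataset`` picks the first dataset 80%
    # of the time; with it disabled, both are picked with probability 0.5.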
    def _preprocess_item(self, item):
        """Preprocess an item to be returned from ``__getitem__``.
        Override in your child task class, so you have control over what
        you are returning.

        Args:
            item (Sample): Sample returned by a particular dataset

        Returns:
            Sample: Preprocessed item
        """
        raise NotImplementedError(
            "This task doesn't implement preprocess_item method"
        )
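    # The simplest possible override returns the item untouched; the module
    # docstring shows a variant that clears a field before returning:
    #
    #     def _preprocess_item(self, item):
    #         return item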
    def update_registry_for_model(self, config):
        """
        Use this if there is some specific configuration required by the
        model which must be inferred at runtime.
        """
        for builder in self.builders:
            builder.update_registry_for_model(config)
    def init_args(self, parser):
        parser.add_argument_group("General Task Arguments")
        parser.add_argument(
            "-dsp",
            "--dataset_size_proportional_sampling",
            type=bool,
            default=0,
            help="Pass if you want to sample from a dataset according to "
            "its size. Default: equal weighted sampling",
        )

        # TODO: Figure out later if we want to init args from datasets
        # self._init_args(parser)
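    # Caution: ``type=bool`` on an argparse option treats any non-empty
    # string, including "0" and "False", as True; only an empty value (or
    # omitting the flag entirely) yields the falsy default.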
    def _init_args(self, parser):
        """Override this function to add extra parameters to the parser in
        your child task class.

        Args:
            parser (ArgumentParser): Original parser object passed from
                higher-level classes like the trainer
        """
        for builder in self.builders:
            builder.init_args(parser)
    def clean_config(self, config):
        """
        Override this in case you want to clean the config you updated
        earlier in ``update_registry_for_model``.
        """
        return config