models.base_model

Models built on top of Pythia must inherit from the BaseModel class and follow a fixed format. To create a model for Pythia, follow this quick cheatsheet.

  1. Inherit from the BaseModel class and make sure to call super().__init__() in your class’s __init__ function.
  2. Implement the build function for your model. If you build everything in __init__, this function can simply return.
  3. Write a forward function that takes a SampleList as an argument and returns a dict.
  4. Register the model using the @registry.register_model("key") decorator on top of the class.

If you are doing logit-based predictions, the dict you return from your model should contain a scores field. Losses and metrics are calculated automatically by the BaseModel class and added to this dict if they are not already present.
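The "added if not present" behaviour can be pictured with a small stand-in that runs without Pythia installed. This is only a sketch: SketchBaseModel, compute_losses, compute_metrics and the "losses" and "metrics" key names are assumptions made for illustration, not Pythia's actual implementation.

```python
# Hypothetical stand-in, NOT Pythia's BaseModel. It only illustrates
# "fill in losses and metrics unless forward() already returned them".
class SketchBaseModel:
    def forward(self, sample_list):
        raise NotImplementedError

    def compute_losses(self, sample_list, output):
        # Placeholder value; a real model would evaluate its configured losses.
        return {"placeholder_loss": 0.0}

    def compute_metrics(self, sample_list, output):
        # Placeholder value; a real model would evaluate its configured metrics.
        return {"placeholder_metric": 0.0}

    def __call__(self, sample_list):
        output = self.forward(sample_list)
        if not isinstance(output, dict):
            raise TypeError("forward() must return a dict")
        # Only add the keys that forward() did not already provide.
        if "losses" not in output:
            output["losses"] = self.compute_losses(sample_list, output)
        if "metrics" not in output:
            output["metrics"] = self.compute_metrics(sample_list, output)
        return output


class ScoresOnly(SketchBaseModel):
    def forward(self, sample_list):
        return {"scores": [[0.1, 0.9] for _ in sample_list]}


class OwnLosses(SketchBaseModel):
    def forward(self, sample_list):
        scores = [[0.1, 0.9] for _ in sample_list]
        return {"scores": scores, "losses": {"custom": 1.5}}


auto_output = ScoresOnly()(["sample_a", "sample_b"])
custom_output = OwnLosses()(["sample_a"])
```

In this sketch, a forward function that returns only scores gets both extra entries filled in, while one that already returns its own losses keeps them untouched.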

Example:

import torch

from pythia.common.registry import registry
from pythia.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_attributes from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        ...

    def forward(self, sample_list):
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
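
The register_model decorator from step 4 follows a common registry pattern: a key from the configuration is mapped to a model class. The sketch below shows that pattern in plain Python. SketchRegistry, its methods and ToyModel are hypothetical stand-ins, and Pythia's own registry may be implemented differently.

```python
# Hypothetical stand-in for a model registry; only the decorator
# pattern is illustrated, not Pythia's actual registry code.
class SketchRegistry:
    def __init__(self):
        self._models = {}

    def register_model(self, key):
        """Return a class decorator that stores the class under `key`."""
        def decorator(model_class):
            if key in self._models:
                raise KeyError(f"model key already registered: {key!r}")
            self._models[key] = model_class
            return model_class  # the class itself is left unchanged
        return decorator

    def get_model_class(self, key):
        return self._models[key]


sketch_registry = SketchRegistry()


@sketch_registry.register_model("toy_model")
class ToyModel:
    def __init__(self, config):
        self.config = config


# Code that only knows the key (for example from a config file)
# can now look up the class and instantiate it.
model_class = sketch_registry.get_model_class("toy_model")
model = model_class({"hidden_size": 8})
```

Because the decorator returns the class unchanged, the registered class can still be imported and used directly.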

class pythia.models.base_model.BaseModel(config)

For integration with Pythia’s trainer, datasets and other features, a model needs to inherit from this class, call super().__init__(), write a build function, write a forward function that takes a SampleList as input and returns a dict as output, and finally register itself using @registry.register_model.

Parameters: config (ConfigNode) – model_attributes configuration from the global config.

build()

Function to be implemented by the child class if it needs to build its model separately from __init__. All model-related downloads should also happen here.
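
One way to picture the split between __init__ and build is the plain-Python sketch below. SketchModel, _load_vocab and the hidden_size key are hypothetical names chosen for illustration. The point is only that __init__ stays cheap while build does the resource loading and layer construction.

```python
class SketchModel:
    """Hypothetical model showing the __init__ / build() split."""

    def __init__(self, config):
        # Cheap: store the configuration, create nothing yet.
        self.config = config
        self.vocab = None
        self.layers = None

    def _load_vocab(self):
        # Placeholder for a download or file read; returns fixed data here.
        return ["yes", "no", "maybe"]

    def build(self):
        # Expensive work lives here: resources first, then layers sized from them.
        self.vocab = self._load_vocab()
        self.layers = [("linear", self.config["hidden_size"], len(self.vocab))]


model = SketchModel({"hidden_size": 16})
built_before = model.layers is not None  # nothing is built in __init__
model.build()
```

Keeping construction in build means an object can be created from its configuration alone, with downloads deferred until the model is actually needed.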

forward(sample_list, *args, **kwargs)

To be implemented by the child class. Takes in a SampleList and returns a dict.

Parameters:
  • sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns:

Dict containing a scores object.

Return type:

Dict
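
The contract described above can be checked with a small helper, for instance in a unit test. This is a sketch under stated assumptions: SketchSampleList is a hypothetical stand-in that exposes only get_batch_size(), toy_forward and check_forward_output are illustrative names, and the "one row of scores per sample" rule is inferred from the example rather than guaranteed by Pythia.

```python
class SketchSampleList:
    """Hypothetical stand-in exposing only get_batch_size()."""

    def __init__(self, samples):
        self._samples = list(samples)

    def get_batch_size(self):
        return len(self._samples)


def toy_forward(sample_list, num_answers=4):
    # Constant scores keep the sketch deterministic; a real model computes them.
    batch_size = sample_list.get_batch_size()
    return {"scores": [[0.0] * num_answers for _ in range(batch_size)]}


def check_forward_output(output, sample_list):
    """Raise if `output` breaks the dict-with-scores contract."""
    if not isinstance(output, dict):
        raise TypeError("forward() must return a dict")
    if "scores" not in output:
        raise KeyError("logit-based models should return a 'scores' entry")
    if len(output["scores"]) != sample_list.get_batch_size():
        raise ValueError("expected one row of scores per sample")
    return True


batch = SketchSampleList(["q1", "q2", "q3"])
output = toy_forward(batch)
is_valid = check_forward_output(output, batch)
```

A check like this catches a forward function that returns a bare tensor or a misnamed key before the output reaches any loss or metric computation.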

init_losses_and_metrics()

Initializes the losses and metrics for the model based on the losses and metrics keys. Automatically called by Pythia internally after building the model.
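
The idea of initializing losses and metrics from two keys can be sketched as a simple name lookup. Everything here is an assumption made for illustration: the keys are taken to live in the model configuration as plain lists of names, and the tables, the names example_mae and example_exact_match, and the helper functions are all hypothetical. Pythia's real configuration format and lookup may differ.

```python
# Hypothetical loss and metric functions, for illustration only.
def mean_absolute_error(scores, targets):
    return sum(abs(s - t) for s, t in zip(scores, targets)) / len(scores)


def exact_match_rate(predictions, targets):
    return sum(1 for p, t in zip(predictions, targets) if p == t) / len(predictions)


# Hypothetical lookup tables mapping config names to callables.
LOSS_TABLE = {"example_mae": mean_absolute_error}
METRIC_TABLE = {"example_exact_match": exact_match_rate}


def init_losses_and_metrics_sketch(config):
    """Resolve the names under 'losses' and 'metrics' to callables."""
    losses = {name: LOSS_TABLE[name] for name in config.get("losses", [])}
    metrics = {name: METRIC_TABLE[name] for name in config.get("metrics", [])}
    return losses, metrics


config = {"losses": ["example_mae"], "metrics": ["example_exact_match"]}
losses, metrics = init_losses_and_metrics_sketch(config)

loss_value = losses["example_mae"]([1.0, 0.0], [0.0, 0.0])
metric_value = metrics["example_exact_match"]([1, 0], [1, 1])
```

Resolving the names once, right after the model is built, means a misspelled loss or metric name fails early instead of in the middle of training.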