models.base_model
Models built on top of Pythia need to inherit the BaseModel class and adhere to a specific format. To create a model for Pythia, follow this quick cheatsheet:
- Inherit the BaseModel class and make sure to call super().__init__() in your class’s __init__ function.
- Implement a build function for your model. If you build everything in __init__, you can simply return from this function.
- Write a forward function that takes a SampleList as an argument and returns a dict.
- Register the model using the @registry.register_model("key") decorator on top of the class.
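The registration step follows the common class-decorator pattern. The sketch below uses a hypothetical stand-alone Registry class, not Pythia's actual pythia.common.registry, to illustrate what a decorator such as register_model("key") typically does: it stores the class under a string key and returns the class unchanged, so the model can later be looked up by name (for example, from a config file). The get_model_class lookup method and the "my_model" key are assumptions of this sketch.

```python
from typing import Callable, Dict


class Registry:
    """Hypothetical minimal registry illustrating the decorator pattern.

    This is NOT Pythia's registry; it only mirrors the idea of mapping
    a string key to a model class.
    """

    def __init__(self) -> None:
        self._models: Dict[str, type] = {}

    def register_model(self, name: str) -> Callable[[type], type]:
        def wrap(cls: type) -> type:
            if name in self._models:
                raise KeyError(f"model key {name!r} is already registered")
            self._models[name] = cls
            return cls  # return the class unchanged so it stays usable directly
        return wrap

    def get_model_class(self, name: str) -> type:
        # Hypothetical lookup by key, e.g. using a model name read from a config.
        return self._models[name]


registry = Registry()


@registry.register_model("my_model")
class MyModel:
    def __init__(self, config):
        self.config = config


model_cls = registry.get_model_class("my_model")
model = model_cls({"hidden_dim": 16})
print(model_cls.__name__)  # MyModel
```

Returning the class from the decorator keeps MyModel importable and usable directly; the registry only adds a name-based lookup on top of it.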
If you are doing logits-based predictions, the dict returned from your model should contain a scores field. Losses and metrics are calculated automatically by the BaseModel class and added to this dict if they are not already present.
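The automatic loss and metric handling can be pictured as a base class whose __call__ runs the child's forward and then fills in any missing entries. This is a minimal plain-Python sketch under stated assumptions, not Pythia's implementation: SketchBaseModel, the loss_fns/metric_fns arguments, and the "losses"/"metrics" output keys are hypothetical names chosen for illustration.

```python
class SketchBaseModel:
    """Hypothetical stand-in for a base model that post-processes forward().

    Not Pythia's BaseModel; names and keys here are illustrative only.
    """

    def __init__(self, loss_fns=None, metric_fns=None):
        # Each entry maps a name to a callable(sample_list, output) -> value.
        self.loss_fns = loss_fns or {}
        self.metric_fns = metric_fns or {}

    def forward(self, sample_list):
        raise NotImplementedError

    def __call__(self, sample_list):
        output = self.forward(sample_list)
        if not isinstance(output, dict):
            raise TypeError("forward() must return a dict")
        # Only compute losses/metrics when the child did not supply them.
        if "losses" not in output:
            output["losses"] = {k: fn(sample_list, output) for k, fn in self.loss_fns.items()}
        if "metrics" not in output:
            output["metrics"] = {k: fn(sample_list, output) for k, fn in self.metric_fns.items()}
        return output


class ToyModel(SketchBaseModel):
    def forward(self, sample_list):
        # One row of logits per sample; a real model would compute these.
        return {"scores": [[0.1, 0.9], [0.8, 0.2]]}


def accuracy(sample_list, output):
    # Predicted class = index of the largest score in each row.
    preds = [row.index(max(row)) for row in output["scores"]]
    return sum(p == t for p, t in zip(preds, sample_list["targets"])) / len(preds)


model = ToyModel(metric_fns={"accuracy": accuracy})
result = model({"targets": [1, 1]})
print(result["metrics"])  # {'accuracy': 0.5}
```

Because the check is "if not present", a child model that computes its own losses in forward keeps full control over them.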
Example:
```python
import torch

from pythia.common.registry import registry
from pythia.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_attributes from the global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        # Construct layers and perform any model-related downloads here.
        pass

    def forward(self, sample_list):
        # Placeholder logits: one row per sample, 3127 answer classes.
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
```
class pythia.models.base_model.BaseModel(config)
For integration with Pythia’s trainer, datasets and other features, models need to inherit this class, call super().__init__(), write a build function, write a forward function that takes a SampleList as input and returns a dict as output, and finally register themselves using @registry.register_model.
Parameters: config (ConfigNode) – model_attributes configuration from the global config.
build()
Function to be implemented by the child class in case it needs to build its model separately from __init__. All model-related downloads should also happen here.
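Keeping __init__ cheap and deferring expensive setup to build() lets a framework create model objects quickly and run heavy work (layer construction, downloads) only when the model is actually needed. A minimal sketch, assuming the caller instantiates the model first and then calls build(); LazyModel, build_model, and the config keys are hypothetical, and (in_dim, out_dim) tuples stand in for real layers:

```python
class LazyModel:
    """Hypothetical model that defers heavy setup from __init__ to build()."""

    def __init__(self, config):
        # Keep __init__ cheap: only store configuration.
        self.config = config
        self.layers = None
        self.built = False

    def build(self):
        # Heavy work (constructing layers, downloading pretrained weights or
        # vocabularies) would go here. Tuples stand in for real layers.
        dims = [self.config["input_dim"]] + self.config["hidden_dims"]
        self.layers = [(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]
        self.built = True


def build_model(model_cls, config):
    # Hypothetical helper; assumed construction order: instantiate, then build().
    model = model_cls(config)
    model.build()
    return model


model = build_model(LazyModel, {"input_dim": 8, "hidden_dims": [16, 4]})
print(model.layers)  # [(8, 16), (16, 4)]
```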
forward(sample_list, *args, **kwargs)
To be implemented by the child class. Takes in a SampleList and returns a dict.
Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns: Dict containing the scores object.
Return type: Dict