Source code for pythia.models.base_model

# Copyright (c) Facebook, Inc. and its affiliates.
"""
Models built on top of Pythia need to inherit ``BaseModel`` class and adhere to
some format. To create a model for Pythia, follow this quick cheatsheet.

1. Inherit ``BaseModel`` class, make sure to call ``super().__init__()`` in your
   class's ``__init__`` function.
2. Implement `build` function for your model. If you build everything in ``__init__``,
   you can just return in this function.
3. Write a `forward` function which takes in a ``SampleList`` as an argument and
   returns a dict.
4. Register using ``@registry.register_model("key")`` decorator on top of the
   class.

If you are doing logits based predictions, the dict you return from your model
should contain a `scores` field. Losses and Metrics are automatically
calculated by the ``BaseModel`` class and added to this dict if not present.

Example::

    import torch

    from pythia.common.registry import registry
    from pythia.models.base_model import BaseModel


    @registry.register_model("pythia")
    class Pythia(BaseModel):
        # config is model_attributes from global config
        def __init__(self, config):
            super().__init__(config)

        def build(self):
            ...

        def forward(self, sample_list):
            scores = torch.rand(sample_list.get_batch_size(), 3127)
            return {"scores": scores}
"""


import collections
import warnings

from torch import nn

from pythia.common.registry import registry
from pythia.common.report import Report
from pythia.modules.losses import Losses
from pythia.modules.metrics import Metrics


class BaseModel(nn.Module):
    """Base class that Pythia models inherit from.

    For integration with Pythia's trainer, datasets and other features, a
    model needs to inherit this class, call ``super().__init__()``, write a
    ``build`` function, write a ``forward`` function taking a ``SampleList``
    as input and returning a dict as output and finally, register itself
    using ``@registry.register_model``.

    Args:
        config (ConfigNode): ``model_attributes`` configuration from global
            config.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Whatever object is currently registered under the "writer" key.
        self.writer = registry.get("writer")

    def build(self):
        """Build the model; must be overridden by the child class.

        Meant for models that construct themselves separately from
        ``__init__``. All model related downloads should also happen here.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "Build method not implemented in the child model class."
        )

    def init_losses_and_metrics(self):
        """Create ``self.losses`` and ``self.metrics`` from the config.

        Reads the ``losses`` and ``metrics`` keys of ``self.config`` (each
        defaulting to an empty list) and wraps them in ``Losses`` and
        ``Metrics``. A warning is emitted for each key that is empty, because
        ``forward`` is then expected to return that entry itself.
        Automatically called by Pythia internally after building the model.
        """
        losses = self.config.get("losses", [])
        metrics = self.config.get("metrics", [])

        if not losses:
            warnings.warn(
                "No losses are defined in model configuration. You are expected "
                "to return loss in your return dict from forward."
            )

        if not metrics:
            warnings.warn(
                "No metrics are defined in model configuration. You are expected "
                "to return metrics in your return dict from forward."
            )

        self.losses = Losses(losses)
        self.metrics = Metrics(metrics)

    @classmethod
    def init_args(cls, parser):
        """Return ``parser`` unchanged.

        Args:
            parser: Object handed in by the caller; not inspected here.

        Returns:
            The same ``parser`` object.
        """
        return parser

    def forward(self, sample_list, *args, **kwargs):
        """To be implemented by child class.

        Takes in a ``SampleList`` and returns back a dict.

        Args:
            sample_list (SampleList): SampleList returned by the DataLoader
                for current iteration.

        Returns:
            Dict: Dict containing scores object.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "Forward of the child model class needs to be implemented."
        )

    def __call__(self, sample_list, *args, **kwargs):
        """Run the model and make sure losses and metrics are in the output.

        Delegates to ``nn.Module.__call__`` (which invokes ``forward``), then
        for ``"losses"`` followed by ``"metrics"``: if the key is already in
        the output it is only validated, otherwise it is filled in by calling
        ``self.losses`` / ``self.metrics`` with ``(sample_list, model_output)``.

        Args:
            sample_list (SampleList): Batch passed through to ``forward``.
            *args: Extra positional arguments passed through to ``forward``.
            **kwargs: Extra keyword arguments passed through to ``forward``.

        Returns:
            The mapping returned by ``forward``, updated in place.

        Raises:
            AssertionError: If the output, or a ``"losses"`` / ``"metrics"``
                entry supplied by ``forward``, is not a mapping.
        """
        model_output = super().__call__(sample_list, *args, **kwargs)

        # Make sure that the output from the model is a Mapping
        self._require_mapping(
            model_output, "A dict must be returned from the forward of the model."
        )

        # Order matters: metrics are computed on an output that already
        # contains the losses entry.
        for key in ("losses", "metrics"):
            self._validate_or_compute(key, sample_list, model_output)

        return model_output

    def _validate_or_compute(self, key, sample_list, model_output):
        """Validate ``model_output[key]`` if present, otherwise compute it."""
        if key in model_output:
            warnings.warn(
                "'{}' already present in model output. "
                "No calculation will be done in base model.".format(key)
            )
            self._require_mapping(
                model_output[key], "'{}' must be a dict.".format(key)
            )
        else:
            # Looked up lazily so that a model which returns its own entry
            # never needs the corresponding attribute to exist.
            model_output[key] = getattr(self, key)(sample_list, model_output)

    @staticmethod
    def _require_mapping(value, message):
        """Raise ``AssertionError(message)`` unless ``value`` is a mapping."""
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        if not isinstance(value, collections.abc.Mapping):
            raise AssertionError(message)