# Source code for pythia.common.registry

# Copyright (c) Facebook, Inc. and its affiliates.
"""
Registry is the central source of truth in Pythia. Inspired by Redux's
concept of a global store, Registry maintains mappings of various information
to unique keys. Special functions in the registry can be used as decorators to
register different kinds of classes.

Import the global registry object using

``from pythia.common.registry import registry``

Various decorators for registering different kinds of classes with unique
keys. Each decorator takes the key as its argument, e.g.
``@registry.register_task("vqa")``:

- Register a task: ``@registry.register_task``
- Register a trainer: ``@registry.register_trainer``
- Register a dataset builder: ``@registry.register_builder``
- Register a metric: ``@registry.register_metric``
- Register a loss: ``@registry.register_loss``
- Register a model: ``@registry.register_model``
- Register a processor: ``@registry.register_processor``
- Register an optimizer: ``@registry.register_optimizer``
- Register a scheduler: ``@registry.register_scheduler``
"""


# Private sentinel that lets ``Registry.get`` tell "key not present" apart
# from a stored value that happens to be ``None`` or equal to ``default``.
_MISSING = object()


class Registry:
    r"""Class for registry object which acts as central source of truth
    for Pythia.

    All data lives in the class-level ``mapping`` dict and every method is a
    classmethod, so all ``Registry`` instances (including the module-level
    ``registry``) share the same state.
    """

    mapping = {
        # Mapping of task name to its respective class. Use the
        # ``register_task`` decorator of this registry to register a task
        # class under a specific name; that name can then be used on the
        # command line or in the configuration to load that specific task.
        "task_name_mapping": {},
        # The mappings below follow the same idea for the other kinds of
        # classes, e.g. ``builder_name_mapping`` maps a dataset name to its
        # builder class (registered with the ``register_builder`` decorator).
        "trainer_name_mapping": {},
        "builder_name_mapping": {},
        "model_name_mapping": {},
        "metric_name_mapping": {},
        "loss_name_mapping": {},
        "optimizer_name_mapping": {},
        "scheduler_name_mapping": {},
        "processor_name_mapping": {},
        # Free-form nested key/value store used by ``register``, ``get`` and
        # ``unregister``; dotted keys address nested dicts.
        "state": {},
    }

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        The decorated class is returned unchanged. Registering a second class
        under the same name overwrites the first one.

        Args:
            name: Key with which the task will be registered.

        Usage::

            from pythia.common.registry import registry
            from pythia.tasks.base_task import BaseTask


            @registry.register_task("vqa")
            class VQATask(BaseTask):
                ...

        """

        def wrap(task_cls):
            # Imported lazily, presumably to avoid a circular import between
            # the registry and the modules that use its decorators.
            from pythia.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All task must inherit BaseTask class"
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_trainer(cls, name):
        r"""Register a trainer to registry with key 'name'

        No base-class check is performed. The decorated class is returned
        unchanged.

        Args:
            name: Key with which the trainer will be registered.

        Usage::

            from pythia.common.registry import registry


            @registry.register_trainer("custom_trainer")
            class CustomTrainer:
                ...

        """

        def wrap(trainer_cls):
            cls.mapping["trainer_name_mapping"][name] = trainer_cls
            return trainer_cls

        return wrap

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        The decorated class is returned unchanged.

        Args:
            name: Key with which the builder will be registered.

        Usage::

            from pythia.common.registry import registry
            from pythia.tasks.base_dataset_builder import BaseDatasetBuilder


            @registry.register_builder("vqa2")
            class VQA2Builder(BaseDatasetBuilder):
                ...

        """

        def wrap(builder_cls):
            from pythia.tasks.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class"
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_metric(cls, name):
        r"""Register a metric to registry with key 'name'

        The decorated class is returned unchanged.

        Args:
            name: Key with which the metric will be registered.

        Usage::

            from pythia.common.registry import registry
            from pythia.modules.metrics import BaseMetric


            @registry.register_metric("r@1")
            class RecallAt1(BaseMetric):
                ...

        """

        def wrap(func):
            from pythia.modules.metrics import BaseMetric

            assert issubclass(
                func, BaseMetric
            ), "All Metric must inherit BaseMetric class"
            cls.mapping["metric_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register_loss(cls, name):
        r"""Register a loss to registry with key 'name'

        The decorated class is returned unchanged.

        Args:
            name: Key with which the loss will be registered.

        Usage::

            from pythia.common.registry import registry
            from torch import nn


            @registry.register_loss("logit_bce")
            class LogitBCE(nn.Module):
                ...

        """

        def wrap(func):
            from torch import nn

            assert issubclass(
                func, nn.Module
            ), "All loss must inherit torch.nn.Module class"
            cls.mapping["loss_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a model to registry with key 'name'

        The decorated class is returned unchanged.

        Args:
            name: Key with which the model will be registered.

        Usage::

            from pythia.common.registry import registry
            from pythia.models.base_model import BaseModel


            @registry.register_model("pythia")
            class Pythia(BaseModel):
                ...

        """

        def wrap(func):
            from pythia.models.base_model import BaseModel

            assert issubclass(
                func, BaseModel
            ), "All models must inherit BaseModel class"
            cls.mapping["model_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        The decorated class is returned unchanged.

        Args:
            name: Key with which the processor will be registered.

        Usage::

            from pythia.common.registry import registry
            from pythia.tasks.processors import BaseProcessor


            @registry.register_processor("glove")
            class GloVe(BaseProcessor):
                ...

        """

        def wrap(func):
            from pythia.tasks.processors import BaseProcessor

            assert issubclass(
                func, BaseProcessor
            ), "All Processor classes must inherit BaseProcessor class"
            cls.mapping["processor_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register_optimizer(cls, name):
        r"""Register an optimizer to registry with key 'name'

        No type check is performed; whatever object is decorated (class or
        function) is stored and returned unchanged.

        Args:
            name: Key with which the optimizer will be registered.

        Usage::

            from pythia.common.registry import registry


            @registry.register_optimizer("my_optimizer")
            class MyOptimizer:
                ...

        """

        def wrap(func):
            cls.mapping["optimizer_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register_scheduler(cls, name):
        r"""Register a scheduler to registry with key 'name'

        No type check is performed; whatever object is decorated (class or
        function) is stored and returned unchanged.

        Args:
            name: Key with which the scheduler will be registered.

        Usage::

            from pythia.common.registry import registry


            @registry.register_scheduler("my_scheduler")
            class MyScheduler:
                ...

        """

        def wrap(func):
            cls.mapping["scheduler_name_mapping"][name] = func
            return func

        return wrap

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        ``name`` may be a dotted path such as ``"config.training"``; missing
        intermediate levels are created as empty dicts, and an existing value
        at the final key is overwritten.

        Args:
            name: Key with which the item will be registered.
            obj: Value to store.

        Usage::

            from pythia.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]
        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[path[-1]] = obj

    @classmethod
    def get_task_class(cls, name):
        r"""Return the task class registered as ``name``, or ``None``."""
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_trainer_class(cls, name):
        r"""Return the trainer class registered as ``name``, or ``None``."""
        return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        r"""Return the dataset builder class registered as ``name``, or ``None``."""
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        r"""Return the model class registered as ``name``, or ``None``."""
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        r"""Return the processor class registered as ``name``, or ``None``."""
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_metric_class(cls, name):
        r"""Return the metric class registered as ``name``, or ``None``."""
        return cls.mapping["metric_name_mapping"].get(name, None)

    @classmethod
    def get_loss_class(cls, name):
        r"""Return the loss class registered as ``name``, or ``None``."""
        return cls.mapping["loss_name_mapping"].get(name, None)

    @classmethod
    def get_optimizer_class(cls, name):
        r"""Return the optimizer registered as ``name``, or ``None``."""
        return cls.mapping["optimizer_name_mapping"].get(name, None)

    @classmethod
    def get_scheduler_class(cls, name):
        r"""Return the scheduler registered as ``name``, or ``None``."""
        return cls.mapping["scheduler_name_mapping"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        ``name`` may be a dotted path into nested dicts. A path that is
        absent, or that runs through a value without a ``get`` method, is
        treated as missing.

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't
                               exist will not be generated. Useful for
                               Pythia's internal operations. Default: False

        Returns:
            The stored value, or ``default`` if the key is missing. The
            warning is written only when the key is actually missing and a
            ``"writer"`` entry is present in the registry state.

        Usage::

            from pythia.common.registry import registry

            config = registry.get("config")
        """
        original_name = name
        value = cls.mapping["state"]
        for subname in name.split("."):
            getter = getattr(value, "get", None)
            if getter is None:
                # The path continues through a non-mapping value.
                value = _MISSING
                break
            value = getter(subname, _MISSING)
            if value is _MISSING:
                break

        if value is not _MISSING:
            # Found: no comparison against ``default`` is made, so stored
            # objects with unusual ``__eq__`` (e.g. tensors) are safe here.
            return value

        if "writer" in cls.mapping["state"] and no_warning is False:
            cls.mapping["state"]["writer"].write(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return default

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        ``name`` may be a dotted path, mirroring ``register`` and ``get``;
        only the final key is removed, intermediate dicts are kept.

        Args:
            name: Key which needs to be removed.

        Returns:
            The removed value, or ``None`` if the key was not present.

        Usage::

            from pythia.common.registry import registry

            config = registry.unregister("config")
        """
        *parents, leaf = name.split(".")
        current = cls.mapping["state"]
        for part in parents:
            if not hasattr(current, "get"):
                return None
            current = current.get(part)
        # Guard against a path that ends on a missing or non-mapping value.
        if not (hasattr(current, "get") and hasattr(current, "pop")):
            return None
        return current.pop(leaf, None)


registry = Registry()