Adding Backbones

Backbones are pre-trained foundation models.

Foundation models are essential to modern ML but are often difficult to work with. Design decisions made during pre-training (tokenization, architecture, I/O format) cannot be changed afterwards. At best, this leads to many reimplementations for benchmarking or finetuning tasks, each with a high risk of buggy code. At worst, these decisions lock users in and exclude certain tasks and use cases.

AIDO.ModelGenerator eliminates the need for reimplementation and makes backbones task-agnostic: wrap your backbone in a standard interface, and reuse it across all inference and finetuning tasks. It also makes compatibility transparent: if a backbone fits the required interface, it can be used for any data-appropriate task.

Note: Backbones for 1D sequence modeling are universally supported. Other types of backbones included in AIDO.ModelGenerator (e.g. structure, image) are not yet universally supported, but will be in the future.

Available Backbones:

  • DNA: aido_dna_7b, aido_dna_300m, aido_dna_dummy, aido_dna_debug, dna_onehot
  • RNA: aido_rna_1b600m, aido_rna_1b600m_cds, aido_rna_650m, aido_rna_650m_cds, aido_rna_300m_mars, aido_rna_25m_mars, aido_rna_1m_mars, aido_dna_dummy, aido_dna_debug, dna_onehot
  • Protein: aido_protein_16b, aido_protein_16b_v1, aido_protein2structoken_16b, aido_protein_debug, protein_onehot, aido_protein_rag_16b, aido_protein_rag_3b
  • Cell (gene expression): aido_cell_100m, aido_cell_10m, aido_cell_3m
  • OneHot: dummy models that only tokenize their input; useful for baselines without a foundation model (FM) and for quick tests

At their core, backbones are PyTorch nn.Module objects with a few extra interfaces. To implement a new backbone, subclass a backbone interface and implement the required methods.
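
The subclass-and-override pattern described above can be sketched without any dependencies. In the snippet below, `ToyInterface` is a hypothetical stand-in for a backbone interface (it is not the library's class), and `unimplemented_methods` is an illustrative helper, not part of AIDO.ModelGenerator, that reports which required methods a subclass still inherits unchanged.

```python
class ToyInterface:
    """Hypothetical stand-in for a backbone interface; not the library class."""

    def tokenize(self, sequences):
        raise NotImplementedError

    def get_vocab_size(self):
        raise NotImplementedError

    def get_max_context(self):
        raise NotImplementedError


class PartialBackbone(ToyInterface):
    """Overrides only some of the required methods."""

    def get_vocab_size(self):
        return 4

    def get_max_context(self):
        return 128


def unimplemented_methods(cls, interface):
    """Return the public interface methods that cls inherits unchanged."""
    required = [
        name
        for name, value in vars(interface).items()
        if callable(value) and not name.startswith("_")
    ]
    # A method counts as unimplemented when the subclass attribute is still
    # the very same function object that the interface defines.
    return sorted(n for n in required if getattr(cls, n) is getattr(interface, n))


print(unimplemented_methods(PartialBackbone, ToyInterface))
```

A check like this is only a convenience for development. Calling an unimplemented method on a backbone still raises `NotImplementedError` at runtime.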

modelgenerator.backbones.SequenceBackboneInterface

Bases: Module

Interface class to ensure consistent implementation of essential methods for all backbones.

Parameters:

Name      Type  Description                  Default
*args     -     The description is missing.  required
**kwargs  -     The description is missing.  required

Attributes:

Name               Type       Description
fsdp_wrap_modules  List[str]  List of module paths to wrap when using distributed training with FSDP.
model_path         str        Path to the model weights. May be HF.

Source code in modelgenerator/backbones/base.py
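# Imports this excerpt relies on (the listing starts partway through the file).
from typing import List, Tuple, Union

from docstring_inheritance import GoogleDocstringInheritanceInitMeta
from torch import Tensor, nn


class SequenceBackboneInterface(nn.Module, metaclass=GoogleDocstringInheritanceInitMeta):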
    """Interface class to ensure consistent implementation of essential methods for all backbones.

    Attributes:
        fsdp_wrap_modules: List of module paths to wrap when using distributed training with FSDP.
        model_path (str): Path to the model weights. May be HF.
    """

    # import paths of modules to wrap when using FSDP
    fsdp_wrap_modules: List[str] = []
    model_path: str = ""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, all_hidden_states: bool = False
    ) -> Union[Tensor, List[Tensor]]:
        """Defines the forward pass for the model.

        Args:
            input_ids (Tensor): Token IDs (n, seq_len).
            attention_mask (Tensor): Attention mask (n, seq_len).
            all_hidden_states (bool, optional): Whether to return all hidden states. Defaults to False.

        Returns:
            Union[Tensor, List[Tensor]]: Model output, typically the last hidden state or logits.
        """
        raise NotImplementedError

    def get_decoder(self) -> nn.Module:
        """Returns the decoder module for the model, if applicable.

        Returns:
            nn.Module: The decoder module.
        """
        raise NotImplementedError

    def tokenize(
        self,
        sequences: List[str],
        padding: bool = True,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Tokenizes input sequences into input IDs and attention masks.

        Args:
            sequences (List[str]): List of input sequences.
            padding (bool, optional): Whether to pad sequences. Defaults to True.
            add_special_tokens (bool, optional): Whether to add special tokens. Defaults to True.

        Returns:
            dict: A dictionary containing input_ids.
        """
        raise NotImplementedError

    def decode_tokens(self, tokenized_sequences: Tensor) -> List[str]:
        """Decodes tokenized sequences back to text.

        Args:
            tokenized_sequences (Tensor): Tokenized sequences.

        Returns:
            List[str]: Decoded text sequences.
        """
        raise NotImplementedError

    def get_token_id(self, token: str) -> int:
        """Gets the ID of a specific token.

        Args:
            token (str): The token to look up.

        Returns:
            int: Token ID.
        """
        raise NotImplementedError

    def get_max_context(self) -> int:
        """Gets the maximum context length of the model.

        Returns:
            int: Maximum context length.
        """
        raise NotImplementedError

    def get_embedding_size(self) -> int:
        """Gets the embedding size of the model.

        Returns:
            int: Embedding size.
        """
        raise NotImplementedError

    def get_vocab_size(self) -> int:
        """Gets the vocabulary size of the model.

        Returns:
            int: Vocabulary size.
        """
        raise NotImplementedError

    def on_save_checkpoint(self, checkpoint: dict):
        """Handles checkpoint saving logic for the model.

        Args:
            checkpoint (dict): The checkpoint dictionary.
        """
        raise NotImplementedError

    def get_num_layer(self) -> int:
        """Gets the number of layers in the model.

        Returns:
            int: Number of layers.
        """
        raise NotImplementedError
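
To illustrate how the tokenization-related methods fit together, here is a minimal, dependency-free sketch. `ToyDNATokenizer` and its vocabulary are hypothetical and exist only for this example. It uses plain Python lists where a real backbone would return tensors, and it returns a dictionary with `input_ids` and `attention_mask`, which is an assumption based on the `tokenize` docstring rather than a confirmed return format.

```python
from typing import Dict, List


class ToyDNATokenizer:
    """Illustrative character-level tokenizer; real backbones return tensors."""

    SPECIAL = ("[PAD]", "[CLS]", "[SEP]")

    def __init__(self, max_context: int = 16):
        self.vocab = list(self.SPECIAL) + ["A", "C", "G", "T"]
        self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
        self.max_context = max_context

    def get_token_id(self, token: str) -> int:
        return self.token_to_id[token]

    def get_vocab_size(self) -> int:
        return len(self.vocab)

    def get_max_context(self) -> int:
        return self.max_context

    def tokenize(
        self,
        sequences: List[str],
        padding: bool = True,
        add_special_tokens: bool = True,
    ) -> Dict[str, List[List[int]]]:
        encoded = []
        for seq in sequences:
            ids = [self.token_to_id[ch] for ch in seq]
            if add_special_tokens:
                ids = [self.get_token_id("[CLS]")] + ids + [self.get_token_id("[SEP]")]
            # Naive truncation to the context length (assumption for this sketch).
            encoded.append(ids[: self.max_context])
        masks = [[1] * len(ids) for ids in encoded]
        if padding:
            # Pad every sequence to the longest one and mark padding with 0.
            width = max((len(ids) for ids in encoded), default=0)
            pad = self.get_token_id("[PAD]")
            masks = [m + [0] * (width - len(m)) for m in masks]
            encoded = [ids + [pad] * (width - len(ids)) for ids in encoded]
        return {"input_ids": encoded, "attention_mask": masks}

    def decode_tokens(self, tokenized_sequences: List[List[int]]) -> List[str]:
        # Drop special tokens so decoding inverts tokenize for valid input.
        return [
            "".join(self.vocab[i] for i in ids if self.vocab[i] not in self.SPECIAL)
            for ids in tokenized_sequences
        ]


tok = ToyDNATokenizer()
batch = tok.tokenize(["ACGT", "GA"])
print(batch["input_ids"])
print(batch["attention_mask"])
print(tok.decode_tokens(batch["input_ids"]))
```

The attention mask marks real tokens with 1 and padding with 0, which matches the `(n, seq_len)` shape that `forward` expects for both `input_ids` and `attention_mask`.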