Adding Backbones

Backbones are pre-trained foundation models.

Foundation models are essential to modern ML but are often difficult to work with. Design decisions made during pre-training (tokenization, architecture, I/O format) cannot be changed afterwards. At best, this leads to many reimplementations for benchmarking or finetuning tasks, each with a high risk of bugs. At worst, these decisions lock users in and exclude certain tasks and use cases.

AIDO.ModelGenerator eliminates the need for reimplementation and makes backbones task-agnostic: wrap your backbone in a standard interface, and reuse it across all inference and finetuning tasks. It also makes compatibility transparent: if a backbone fits the required interface, it can be used for any data-appropriate task.

Note: Backbones for 1D sequence modeling are universally supported. Other types of backbones included in AIDO.ModelGenerator (e.g. structure, image) are not yet universally supported, but will be in the future.

Available Backbones:

  • DNA: aido_dna_7b, aido_dna_300m, aido_dna_dummy, aido_dna_debug, dna_onehot
  • RNA: aido_rna_1b600m, aido_rna_1b600m_cds, aido_rna_650m, aido_rna_650m_cds, aido_rna_300m_mars, aido_rna_25m_mars, aido_rna_1m_mars, aido_dna_dummy, aido_dna_debug, dna_onehot
  • Protein: aido_protein_16b, aido_protein_16b_v1, aido_protein2structoken_16b, aido_protein_debug, protein_onehot
  • Cell (gene expression): aido_cell_100m, aido_cell_10m, aido_cell_3m
  • OneHot: dummy model, only tokenizes, useful for non-FM baselines and quick tests

At their core, backbones are PyTorch nn.Module objects with a few extra interfaces. To implement a new backbone, subclass a backbone interface and implement the required methods.

modelgenerator.backbones.SequenceBackboneInterface

Bases: Module

Interface class to ensure consistent implementation of essential methods for all backbones.

Source code in modelgenerator/backbones/base.py
class SequenceBackboneInterface(nn.Module):
    """Interface class to ensure consistent implementation of essential methods for all backbones."""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, all_hidden_states: bool = False
    ) -> Union[Tensor, List[Tensor]]:
        """Defines the forward pass for the model.

        Args:
            input_ids (Tensor): Token IDs (n, seq_len).
            attention_mask (Tensor): Attention mask (n, seq_len).
            all_hidden_states (bool, optional): Whether to return all hidden states. Defaults to False.

        Returns:
            Union[Tensor, List[Tensor]]: Model output, typically the last hidden state or logits.
        """
        raise NotImplementedError

    def get_decoder(self) -> nn.Module:
        """Returns the decoder module for the model, if applicable.

        Returns:
            nn.Module: The decoder module.
        """
        raise NotImplementedError

    def tokenize(
        self,
        sequences: List[str],
        padding: bool = True,
        add_special_tokens: bool = True,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Tokenizes input sequences into input IDs and attention masks.

        Args:
            sequences (List[str]): List of input sequences.
            padding (bool, optional): Whether to pad sequences. Defaults to True.
            add_special_tokens (bool, optional): Whether to add special tokens. Defaults to True.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Token IDs, attention masks, and special tokens mask.
        """
        raise NotImplementedError

    def decode_tokens(self, tokenized_sequences: Tensor) -> List[str]:
        """Decodes tokenized sequences back to text.

        Args:
            tokenized_sequences (Tensor): Tokenized sequences.

        Returns:
            List[str]: Decoded text sequences.
        """
        raise NotImplementedError

    def get_token_id(self, token: str) -> int:
        """Gets the ID of a specific token.

        Args:
            token (str): The token to look up.

        Returns:
            int: Token ID.
        """
        raise NotImplementedError

    def get_max_context(self) -> int:
        """Gets the maximum context length of the model.

        Returns:
            int: Maximum context length.
        """
        raise NotImplementedError

    def get_embedding_size(self) -> int:
        """Gets the embedding size of the model.

        Returns:
            int: Embedding size.
        """
        raise NotImplementedError

    def get_vocab_size(self) -> int:
        """Gets the vocabulary size of the model.

        Returns:
            int: Vocabulary size.
        """
        raise NotImplementedError

    def on_save_checkpoint(self, checkpoint: dict):
        """Handles checkpoint saving logic for the model.

        Args:
            checkpoint (dict): The checkpoint dictionary.
        """
        raise NotImplementedError

    def get_num_layer(self) -> int:
        """Gets the number of layers in the model.

        Returns:
            int: Number of layers.
        """
        raise NotImplementedError
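
As a rough illustration of what filling in this interface might look like, the sketch below implements every method for a toy character-level model. The class `ToyCharBackbone`, its vocabulary, its special tokens, and its layer layout are all hypothetical and are not part of AIDO.ModelGenerator. It subclasses `nn.Module` directly so that it runs without the library installed; in a real project the base class would be `modelgenerator.backbones.SequenceBackboneInterface`. The output shape `(n, seq_len, hidden)` is also an assumption.

```python
from typing import List, Tuple, Union

import torch
from torch import Tensor, nn


# In a real project, subclass modelgenerator.backbones.SequenceBackboneInterface;
# nn.Module is used here only so the sketch runs without the library installed.
class ToyCharBackbone(nn.Module):
    """Hypothetical character-level backbone (not part of AIDO.ModelGenerator)."""

    def __init__(self, alphabet: str = "ACGT", hidden: int = 8, layers: int = 2, max_len: int = 16):
        super().__init__()
        # Ids 0, 1, 2 are reserved for the (assumed) special tokens.
        self.vocab = ["[PAD]", "[CLS]", "[SEP]"] + list(alphabet)
        self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
        self.max_len = max_len
        self.embed = nn.Embedding(len(self.vocab), hidden, padding_idx=0)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(layers)])
        self.decoder = nn.Linear(hidden, len(self.vocab))

    def tokenize(
        self, sequences: List[str], padding: bool = True, add_special_tokens: bool = True
    ) -> Tuple[Tensor, Tensor, Tensor]:
        rows = []
        for seq in sequences:
            ids = [self.token_to_id[ch] for ch in seq]
            rows.append([1] + ids + [2] if add_special_tokens else ids)
        width = max(len(row) for row in rows)
        if not padding and any(len(row) != width for row in rows):
            raise ValueError("padding=False needs equal-length sequences in this sketch")
        # Pad with id 0, then derive both masks from the padded ids.
        input_ids = torch.tensor([row + [0] * (width - len(row)) for row in rows])
        attention_mask = (input_ids != 0).long()
        special_tokens_mask = (input_ids <= 2).long()  # PAD, CLS, SEP
        return input_ids, attention_mask, special_tokens_mask

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, all_hidden_states: bool = False
    ) -> Union[Tensor, List[Tensor]]:
        hidden = self.embed(input_ids)  # (n, seq_len, hidden)
        keep = attention_mask.unsqueeze(-1).to(hidden.dtype)
        states = [hidden]
        for layer in self.layers:
            hidden = torch.tanh(layer(hidden)) * keep  # zero out padded positions
            states.append(hidden)
        return states if all_hidden_states else hidden

    def get_decoder(self) -> nn.Module:
        return self.decoder

    def decode_tokens(self, tokenized_sequences: Tensor) -> List[str]:
        # Skip the three special-token ids when rebuilding the text.
        return ["".join(self.vocab[i] for i in row if i > 2) for row in tokenized_sequences.tolist()]

    def get_token_id(self, token: str) -> int:
        return self.token_to_id[token]

    def get_max_context(self) -> int:
        return self.max_len

    def get_embedding_size(self) -> int:
        return self.embed.embedding_dim

    def get_vocab_size(self) -> int:
        return len(self.vocab)

    def get_num_layer(self) -> int:
        return len(self.layers)

    def on_save_checkpoint(self, checkpoint: dict):
        pass  # nothing special to do for this toy model


backbone = ToyCharBackbone()
input_ids, attention_mask, special_tokens_mask = backbone.tokenize(["ACGT", "AC"])
output = backbone(input_ids, attention_mask)
```

Here the two sequences are padded to a common length of 6 (four bases plus two special tokens), so `output` has shape `(2, 6, 8)`, and `backbone.decode_tokens(input_ids)` recovers the original strings.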

decode_tokens(tokenized_sequences)

Decodes tokenized sequences back to text.

Parameters:

Name                 Type    Description           Default
tokenized_sequences  Tensor  Tokenized sequences.  required

Returns:

Type       Description
List[str]  Decoded text sequences.

Source code in modelgenerator/backbones/base.py
def decode_tokens(self, tokenized_sequences: Tensor) -> List[str]:
    """Decodes tokenized sequences back to text.

    Args:
        tokenized_sequences (Tensor): Tokenized sequences.

    Returns:
        List[str]: Decoded text sequences.
    """
    raise NotImplementedError

forward(input_ids, attention_mask, all_hidden_states=False)

Defines the forward pass for the model.

Parameters:

Name               Type    Description                                              Default
input_ids          Tensor  Token IDs (n, seq_len).                                  required
attention_mask     Tensor  Attention mask (n, seq_len).                             required
all_hidden_states  bool    Whether to return all hidden states. Defaults to False.  False

Returns:

Type                         Description
Union[Tensor, List[Tensor]]  Model output, typically the last hidden state or logits.

Source code in modelgenerator/backbones/base.py
def forward(
    self, input_ids: Tensor, attention_mask: Tensor, all_hidden_states: bool = False
) -> Union[Tensor, List[Tensor]]:
    """Defines the forward pass for the model.

    Args:
        input_ids (Tensor): Token IDs (n, seq_len).
        attention_mask (Tensor): Attention mask (n, seq_len).
        all_hidden_states (bool, optional): Whether to return all hidden states. Defaults to False.

    Returns:
        Union[Tensor, List[Tensor]]: Model output, typically the last hidden state or logits.
    """
    raise NotImplementedError
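
Downstream code usually has to handle both return types and ignore padded positions. The sketch below shows one way to do that, under two assumptions that the interface itself does not state: the output has shape `(n, seq_len, hidden_size)`, and a returned list is ordered from the first layer to the last. The helper names `last_hidden_state` and `masked_mean_pool` are hypothetical.

```python
from typing import List, Union

import torch
from torch import Tensor


def last_hidden_state(output: Union[Tensor, List[Tensor]]) -> Tensor:
    """Pick the final layer, assuming a list is ordered from first to last layer."""
    return output[-1] if isinstance(output, (list, tuple)) else output


def masked_mean_pool(hidden: Tensor, attention_mask: Tensor) -> Tensor:
    """Average hidden states over positions where attention_mask is 1.

    hidden: (n, seq_len, hidden_size); attention_mask: (n, seq_len).
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (n, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                    # (n, hidden_size)
    counts = mask.sum(dim=1).clamp(min=1.0)                # avoid division by zero
    return summed / counts


# One sequence of length 3 whose last position is padding.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
pooled = masked_mean_pool(last_hidden_state(hidden), attention_mask)
```

The padded position is excluded from the average, so `pooled` holds the mean of the first two rows only.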

get_decoder()

Returns the decoder module for the model, if applicable.

Returns:

Type    Description
Module  The decoder module.

Source code in modelgenerator/backbones/base.py
def get_decoder(self) -> nn.Module:
    """Returns the decoder module for the model, if applicable.

    Returns:
        nn.Module: The decoder module.
    """
    raise NotImplementedError

get_embedding_size()

Gets the embedding size of the model.

Returns:

Type  Description
int   Embedding size.

Source code in modelgenerator/backbones/base.py
def get_embedding_size(self) -> int:
    """Gets the embedding size of the model.

    Returns:
        int: Embedding size.
    """
    raise NotImplementedError

get_max_context()

Gets the maximum context length of the model.

Returns:

Type  Description
int   Maximum context length.

Source code in modelgenerator/backbones/base.py
def get_max_context(self) -> int:
    """Gets the maximum context length of the model.

    Returns:
        int: Maximum context length.
    """
    raise NotImplementedError
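
One common use of the context limit is splitting an over-long input into windows before tokenizing. The sketch below is a minimal illustration that assumes one token per character and assumes the limit includes two special tokens per window; neither assumption comes from the interface. The helper `chunk_sequence` is hypothetical.

```python
from typing import List


def chunk_sequence(sequence: str, max_context: int, num_special_tokens: int = 2) -> List[str]:
    """Split a sequence into consecutive windows that fit within max_context.

    Assumes one token per character and that special tokens count
    against the context limit.
    """
    window = max_context - num_special_tokens
    if window <= 0:
        raise ValueError("max_context must exceed num_special_tokens")
    return [sequence[i : i + window] for i in range(0, len(sequence), window)]


# With a backbone, max_context would come from backbone.get_max_context().
chunks = chunk_sequence("ACGTACGTAC", max_context=6)
print(chunks)  # ['ACGT', 'ACGT', 'AC']
```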

get_num_layer()

Gets the number of layers in the model.

Returns:

Type  Description
int   Number of layers.

Source code in modelgenerator/backbones/base.py
def get_num_layer(self) -> int:
    """Gets the number of layers in the model.

    Returns:
        int: Number of layers.
    """
    raise NotImplementedError

get_token_id(token)

Gets the ID of a specific token.

Parameters:

Name   Type  Description            Default
token  str   The token to look up.  required

Returns:

Type  Description
int   Token ID.

Source code in modelgenerator/backbones/base.py
def get_token_id(self, token: str) -> int:
    """Gets the ID of a specific token.

    Args:
        token (str): The token to look up.

    Returns:
        int: Token ID.
    """
    raise NotImplementedError

get_vocab_size()

Gets the vocabulary size of the model.

Returns:

Type  Description
int   Vocabulary size.

Source code in modelgenerator/backbones/base.py
def get_vocab_size(self) -> int:
    """Gets the vocabulary size of the model.

    Returns:
        int: Vocabulary size.
    """
    raise NotImplementedError

on_save_checkpoint(checkpoint)

Handles checkpoint saving logic for the model.

Parameters:

Name        Type  Description                 Default
checkpoint  dict  The checkpoint dictionary.  required

Source code in modelgenerator/backbones/base.py
def on_save_checkpoint(self, checkpoint: dict):
    """Handles checkpoint saving logic for the model.

    Args:
        checkpoint (dict): The checkpoint dictionary.
    """
    raise NotImplementedError
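
The interface does not say what a backbone should do in this hook. One plausible use, sketched below, is dropping weights that can be restored from the pre-trained model so that checkpoints stay small. This assumes the checkpoint dictionary follows the PyTorch Lightning layout with a `"state_dict"` entry; the helper `drop_prefixed_weights` and the key prefix are hypothetical.

```python
from typing import Any, Dict


def drop_prefixed_weights(checkpoint: Dict[str, Any], prefix: str = "backbone.encoder.") -> None:
    """Remove state_dict entries whose key starts with `prefix`, in place.

    Assumes a Lightning-style checkpoint with a "state_dict" entry;
    the default prefix is a made-up example.
    """
    state_dict = checkpoint.get("state_dict", {})
    # Collect keys first so the dict is not modified while iterating over it.
    for key in [k for k in state_dict if k.startswith(prefix)]:
        del state_dict[key]


checkpoint = {
    "state_dict": {
        "backbone.encoder.layer0.weight": [0.1],
        "adapter.head.weight": [0.2],
    }
}
drop_prefixed_weights(checkpoint)
print(list(checkpoint["state_dict"]))  # ['adapter.head.weight']
```

A backbone that needs no special handling could implement the hook as a no-op instead.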

tokenize(sequences, padding=True, add_special_tokens=True)

Tokenizes input sequences into input IDs and attention masks.

Parameters:

Name                Type       Description                                           Default
sequences           List[str]  List of input sequences.                              required
padding             bool       Whether to pad sequences. Defaults to True.           True
add_special_tokens  bool       Whether to add special tokens. Defaults to True.      True

Returns:

Type                           Description
Tuple[Tensor, Tensor, Tensor]  Token IDs, attention masks, and special tokens mask.

Source code in modelgenerator/backbones/base.py
def tokenize(
    self,
    sequences: List[str],
    padding: bool = True,
    add_special_tokens: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Tokenizes input sequences into input IDs and attention masks.

    Args:
        sequences (List[str]): List of input sequences.
        padding (bool, optional): Whether to pad sequences. Defaults to True.
        add_special_tokens (bool, optional): Whether to add special tokens. Defaults to True.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Token IDs, attention masks, and special tokens mask.
    """
    raise NotImplementedError