Adding Data Loaders

AIDO.ModelGenerator uses Lightning DataModules for dataset management and loading. It also provides a few tools that make data management more convenient and work with common file types out of the box.

AIDO.ModelGenerator provides a DataInterface class that hides boilerplate, along with an HFDatasetLoaderMixin that combines the Lightning DataModule structure with the convenience of HuggingFace Datasets to quickly load data from HuggingFace or from common file formats (e.g. tsv, csv, json). More convenient mixins and example usage are outlined below.

Many common tasks and data loaders are already implemented in AIDO.ModelGenerator and only require new paths to run on new data. See the Data API Reference for all available data modules.
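
For example, to run an existing task on new data it is usually enough to point its data module at a new path. The sketch below uses MLMDataModule (documented later on this page) with a hypothetical dataset identifier; the arguments shown come from DataInterface, and any task-specific arguments are assumed to have workable defaults.

from modelgenerator.data import MLMDataModule

# Hypothetical dataset id; a local folder path works the same way.
dm = MLMDataModule(
    path="my-org/my-sequence-dataset",
    batch_size=32,
    random_seed=42,
)
dm.setup()
train_loader = dm.train_dataloader()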

modelgenerator.data.DataInterface

Bases: LightningDataModule, KFoldMixin

Base class for all data modules in this project. Handles the boilerplate of setting up data loaders.

Note

Subclasses must implement the setup method. All datasets should return a dictionary of data items. To use HF loading, add the HFDatasetLoaderMixin. For any task-specific behaviors, implement transformations using torch.utils.data.Dataset objects. See MLM for an example.

Parameters:

path (str, required): Path to the dataset; can be a local path to a data folder or a Huggingface dataset identifier.
config_name (str, default None): The name of the dataset configuration. It affects how the dataset is loaded.
train_split_name (str, default 'train'): The name of the training split.
test_split_name (str, default 'test'): The name of the test split.
valid_split_name (str, default None): The name of the validation split.
train_split_files (List[str], default None): Create a split called "train" from these files. Not used unless referenced by the name "train" in one of the split_name arguments.
test_split_files (List[str], default None): Create a split called "test" from these files. Not used unless referenced by the name "test" in one of the split_name arguments.
valid_split_files (List[str], default None): Create a split called "valid" from these files. Not used unless referenced by the name "valid" in one of the split_name arguments.
test_split_size (float, default 0.2): The size of the test split. If test_split_name is None, creates a test split of this size from the training split.
valid_split_size (float, default 0.1): The size of the validation split. If valid_split_name is None, creates a validation split of this size from the training split.
random_seed (int, default 42): The random seed to use for splitting the data.
batch_size (int, default 128): The batch size.
shuffle (bool, default True): Whether to shuffle the data.
sampler (Optional[Sampler], default None): The sampler to use.
num_workers (int, default 0): The number of workers to use for data loading.
collate_fn (Optional[callable], default None): The function to use for collating data.
pin_memory (bool, default True): Whether to pin memory.
persistent_workers (bool, default False): Whether to use persistent workers.
cv_num_folds (int, default 1): The number of cross-validation folds; disables cross-validation when <= 1.
cv_test_fold_id (int, default 0): The fold id to use for cross-validation evaluation.
cv_enable_val_fold (bool, default True): Whether to enable a validation fold.
cv_fold_id_col (Optional[str], default None): The column name containing the fold id from a pre-split dataset. Set to None to enable automatic splitting.
cv_val_offset (int, default 1): The offset applied to cv_test_fold_id to determine the validation fold id.
Source code in modelgenerator/data/base.py
class DataInterface(pl.LightningDataModule, KFoldMixin):
    """Base class for all data modules in this project. Handles the boilerplate of setting up data loaders.

    Note:
        Subclasses must implement the setup method.
        All datasets should return a dictionary of data items.
        To use HF loading, add the HFDatasetLoaderMixin.
        For any task-specific behaviors, implement transformations using `torch.utils.data.Dataset` objects.
        See [MLM](./#modelgenerator.data.MLMDataModule) for an example.

    Args:
        path (str): Path to the dataset, can be one of:
            1. a local path to a data folder
            2. a Huggingface dataset identifier
        config_name (str, optional): The name of the dataset configuration.
          It affects how the dataset is loaded.
        train_split_name (str, optional): The name of the training split. Defaults to "train".
        test_split_name (str, optional): The name of the test split. Defaults to "test".
        valid_split_name (str, optional): The name of the validation split. Defaults to None.
        train_split_files (List[str], optional): Create a split called "train" from these files.
            Not used unless referenced by the name "train" in one of the split_name arguments.
        test_split_files (List[str], optional): Create a split called "test" from these files.
            Not used unless referenced by the name "test" in one of the split_name arguments.
        valid_split_files (List[str], optional): Create a split called "valid" from these files.
            Not used unless referenced by the name "valid" in one of the split_name arguments.
        test_split_size (float, optional): The size of the test split.
           If test_split_name is None, creates a test split of this size from the training split.
        valid_split_size (float, optional): The size of the validation split.
           If valid_split_name is None, creates a validation split of this size from the training split.
        random_seed (int, optional): The random seed to use for splitting the data. Defaults to 42.
        batch_size (int, optional): The batch size. Defaults to 128.
        shuffle (bool, optional): Whether to shuffle the data. Defaults to True.
        sampler (Optional[torch.utils.data.Sampler], optional): The sampler to use. Defaults to None.
        num_workers (int, optional): The number of workers to use for data loading. Defaults to 0.
        collate_fn (Optional[callable], optional): The function to use for collating data. Defaults to None.
        pin_memory (bool, optional): Whether to pin memory. Defaults to True.
        persistent_workers (bool, optional): Whether to use persistent workers. Defaults to False.
        cv_num_folds (int, optional): The number of cross-validation folds, disables cv when <= 1. Defaults to 1.
        cv_test_fold_id (int, optional): The fold id to use for cross-validation evaluation. Defaults to 0.
        cv_enable_val_fold (bool, optional): Whether to enable a validation fold. Defaults to True.
        cv_fold_id_col (Optional[str], optional): The column name containing the fold id from a pre-split dataset. Set to None to enable automatic splitting. Defaults to None.
        cv_val_offset (int, optional): The offset applied to cv_test_fold_id to determine val_fold_id.
    """

    def __init__(
        self,
        path: str,
        config_name: Optional[str] = None,
        train_split_name: Optional[str] = "train",
        test_split_name: Optional[str] = "test",
        valid_split_name: Optional[str] = None,
        train_split_files: Optional[List[str]] = None,
        test_split_files: Optional[List[str]] = None,
        valid_split_files: Optional[List[str]] = None,
        test_split_size: float = 0.2,
        valid_split_size: float = 0.1,
        random_seed: int = 42,
        batch_size: int = 128,
        shuffle: bool = True,
        sampler: Optional[torch.utils.data.Sampler] = None,
        num_workers: int = 0,
        collate_fn: Optional[callable] = None,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        # TODO: cv params will be deprecated and will be handled by trainer directly
        cv_num_folds: int = 1,
        cv_test_fold_id: int = 0,
        cv_enable_val_fold: bool = True,
        cv_fold_id_col: Optional[str] = None,
        cv_val_offset: int = 1,
    ):
        super().__init__()
        if os.path.isfile(path):
            raise ValueError(
                "Path must be a directory or a Huggingface dataset repo. "
                "If you want to pass only one file, set the path to the directory "
                "containing the file and set `*_split_files` to `[filename]`."
            )
        self.path = path
        self.config_name = config_name
        self.train_split_name = train_split_name
        self.test_split_name = test_split_name
        self.valid_split_name = valid_split_name
        self.train_split_files = train_split_files
        self.test_split_files = test_split_files
        self.valid_split_files = valid_split_files
        self.test_split_size = test_split_size
        self.valid_split_size = valid_split_size
        self.random_seed = random_seed
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = sampler
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.cv_num_folds = cv_num_folds
        self.cv_test_fold_id = cv_test_fold_id
        self.cv_enable_val_fold = cv_enable_val_fold
        self.cv_fold_id_col = cv_fold_id_col
        self.cv_val_offset = cv_val_offset

    def setup(self, stage: Optional[str] = None) -> None:
        """Set up the data module. This method should be overridden by subclasses.

        Args:
            stage (Optional[str], optional): training, validation, or test if these need to be setup separately. Defaults to None.
        """
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def train_dataloader(self) -> DataLoader:
        """Get the training data loader

        Returns:
            DataLoader: The training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            sampler=self.sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation data loader

        Returns:
            DataLoader: The validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=self.sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """Get the test data loader

        Returns:
            DataLoader: The test data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=self.sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
        )

    def predict_dataloader(self) -> DataLoader:
        """Get the dataloader for predictions for the test set

        Returns:
            DataLoader: The predict data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=self.sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
        )

setup(stage=None)

Set up the data module. This method should be overridden by subclasses.

Parameters:

stage (Optional[str], default None): training, validation, or test, if these need to be set up separately.
Source code in modelgenerator/data/base.py
def setup(self, stage: Optional[str] = None) -> None:
    """Set up the data module. This method should be overridden by subclasses.

    Args:
        stage (Optional[str], optional): training, validation, or test if these need to be setup separately. Defaults to None.
    """
    self.train_dataset = None
    self.val_dataset = None
    self.test_dataset = None
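
To make the Note above concrete, here is a minimal sketch of a custom data module that combines DataInterface with HFDatasetLoaderMixin; the class name is hypothetical and no task-specific transformation is applied.

class MyDataModule(DataInterface, HFDatasetLoaderMixin):
    """Hypothetical data module that loads train/validation/test splits via the HF mixin."""

    def setup(self, stage=None):
        # load_and_split_dataset comes from HFDatasetLoaderMixin: it loads the named
        # splits (or split files) and carves missing validation/test splits out of
        # the training data using valid_split_size / test_split_size.
        self.train_dataset, self.val_dataset, self.test_dataset = (
            self.load_and_split_dataset()
        )

HuggingFace Datasets already return a dictionary per sample, so this satisfies the requirement that all datasets return a dictionary of data items.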

Useful Mixins

modelgenerator.data.HFDatasetLoaderMixin

Provides methods for loading datasets using the Huggingface datasets library.

Source code in modelgenerator/data/base.py
class HFDatasetLoaderMixin:
    """Provides methods for loading datasets using the Huggingface datasets library."""

    def load_dataset(self) -> Tuple[datasets.Dataset]:
        split_names = [
            self.train_split_name,
            self.valid_split_name,
            self.test_split_name,
        ]
        data_files = {}
        if self.train_split_files:
            data_files["train"] = self.train_split_files
        if self.valid_split_files:
            data_files["valid"] = self.valid_split_files
        if self.test_split_files:
            data_files["test"] = self.test_split_files
        splits = ()
        for split_name in split_names:
            if split_name is None:
                splits += (None,)
            else:
                splits += (
                    load_dataset(
                        self.path,
                        data_files=None if not data_files else data_files,
                        name=self.config_name,
                        streaming=False,
                        split=split_name,
                    ),
                )
        return splits

    def load_and_split_dataset(self) -> Tuple[datasets.Dataset]:
        train_dataset, valid_dataset, test_dataset = self.load_dataset()
        if test_dataset is None and self.test_split_size > 0:
            rank_zero_info(
                f"> Randomly split {self.test_split_size} of train for testing, Random seed: {self.random_seed}"
            )
            train_test_split = train_dataset.train_test_split(
                test_size=self.test_split_size, seed=self.random_seed
            )
            train_dataset = train_test_split["train"]
            test_dataset = train_test_split["test"]
        if valid_dataset is None and self.valid_split_size > 0:
            rank_zero_info(
                f"> Randomly split {self.valid_split_size} of train for validation. Random seed: {self.random_seed}"
            )
            train_test_split = train_dataset.train_test_split(
                test_size=self.valid_split_size, seed=self.random_seed
            )
            train_dataset = train_test_split["train"]
            valid_dataset = train_test_split["test"]
        first_non_empty = train_dataset or valid_dataset or test_dataset
        if first_non_empty is None:
            raise ValueError("All splits are empty")
        # return empty datasets instead of None for easier handling
        if train_dataset is None:
            train_dataset = datasets.Dataset.from_dict(
                {k: [] for k in first_non_empty.column_names}
            )
        if valid_dataset is None:
            valid_dataset = datasets.Dataset.from_dict(
                {k: [] for k in first_non_empty.column_names}
            )
        if test_dataset is None:
            test_dataset = datasets.Dataset.from_dict(
                {k: [] for k in first_non_empty.column_names}
            )
        return train_dataset, valid_dataset, test_dataset
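
The split arguments defined on DataInterface control what this mixin loads. A sketch of the arguments for reading explicit local files (file names are hypothetical):

split_kwargs = dict(
    path="data/my_task",              # folder containing the files below
    train_split_files=["train.csv"],
    valid_split_files=["valid.csv"],
    test_split_files=["test.csv"],
    train_split_name="train",
    valid_split_name="valid",
    test_split_name="test",
)
# Passed to any DataInterface subclass, load_and_split_dataset() then returns
# (train, valid, test) datasets. If valid_split_name or test_split_name is None,
# the corresponding split is instead carved out of train using valid_split_size /
# test_split_size with random_seed.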

modelgenerator.data.KFoldMixin

Provides methods for splitting datasets into k folds for cross-validation.

Source code in modelgenerator/data/base.py
class KFoldMixin:
    """Provides methods for splitting datasets into k-folds for cross-validation"""

    def __init__(self):
        self.cv_splits = None

    def get_split_by_fold_id(
        self, train_dataset, val_dataset, test_dataset, fold_id, val_idx_offset=1
    ):
        """Split the dataset into training, validation, and test sets based on the fold id for test"""
        if self.cv_num_folds <= 1:
            return train_dataset, val_dataset, test_dataset
        if len(val_dataset) or len(test_dataset):
            rank_zero_warn(
                "Redundant val or test splits are not expected and will be ignored during training. "
                "Disable this warning by setting {test,val}_split_size=0 and {test,val}_split_name=None"
            )
        if self.cv_fold_id_col is not None:
            splits = self.read_kfold_split(train_dataset)
        else:
            splits = self.generate_kfold_split(len(train_dataset), self.cv_num_folds)
        test_idx = splits[fold_id]
        val_idx = (
            splits[(fold_id + val_idx_offset) % self.cv_num_folds]
            if self.cv_enable_val_fold
            else []
        )
        train_idx = list(set(range(len(train_dataset))) - set(test_idx) - set(val_idx))
        return (
            train_dataset.select(train_idx),
            train_dataset.select(val_idx),
            train_dataset.select(test_idx),
        )

    def generate_kfold_split(
        self, num_samples: int, num_folds: int, shuffle: bool = True
    ):
        """Randomly split n_samples into n_splits folds and return list of fold_idx

        Args:
            num_samples (int): Number of samples in the data.
            num_folds (Optional[int]): Number of folds for cross validation, must be > 1. Defaults to 10.
            shuffle (Optional[bool]): Whether to shuffle the data before splitting into batches. Defaults to True.

        Returns:
            list of list containing indices for each fold
        """
        if self.cv_splits is not None:
            return self.cv_splits
        idx = np.arange(num_samples)
        if shuffle:
            np.random.seed(self.random_seed)
            np.random.shuffle(idx)
        fold_samples = num_samples // num_folds
        kfold_split_idx = []
        for k in range(num_folds - 1):
            kfold_split_idx.append(
                idx[k * fold_samples : (k + 1) * fold_samples].tolist()
            )
        kfold_split_idx.append(idx[(k + 1) * fold_samples :].tolist())
        self.cv_splits = kfold_split_idx
        return kfold_split_idx

    def read_kfold_split(self, dataset: datasets.Dataset):
        """Read the fold ids from a pre-split dataset and return list of fold_idx"""
        fold_ids = sorted(dataset.unique(self.cv_fold_id_col))
        if list(range(self.cv_num_folds)) != fold_ids:
            raise ValueError(f"Fold ids {fold_ids} do not match the expected range")
        kfold_split_idx = []
        for fold_id in fold_ids:
            kfold_split_idx.append(
                np.where(np.array(dataset[self.cv_fold_id_col], dtype=int) == fold_id)[
                    0
                ]
            )
        self.cv_splits = kfold_split_idx
        return kfold_split_idx

generate_kfold_split(num_samples, num_folds, shuffle=True)

Randomly split num_samples into num_folds folds and return a list of fold indices.

Parameters:

num_samples (int, required): Number of samples in the data.
num_folds (int, required): Number of folds for cross-validation; must be > 1.
shuffle (bool, default True): Whether to shuffle the data before splitting into folds.

Returns:

A list of lists containing the sample indices for each fold.

Source code in modelgenerator/data/base.py
def generate_kfold_split(
    self, num_samples: int, num_folds: int, shuffle: bool = True
):
    """Randomly split n_samples into n_splits folds and return list of fold_idx

    Args:
        num_samples (int): Number of samples in the data.
        num_folds (Optional[int]): Number of folds for cross validation, must be > 1. Defaults to 10.
        shuffle (Optional[bool]): Whether to shuffle the data before splitting into batches. Defaults to True.

    Returns:
        list of list containing indices for each fold
    """
    if self.cv_splits is not None:
        return self.cv_splits
    idx = np.arange(num_samples)
    if shuffle:
        np.random.seed(self.random_seed)
        np.random.shuffle(idx)
    fold_samples = num_samples // num_folds
    kfold_split_idx = []
    for k in range(num_folds - 1):
        kfold_split_idx.append(
            idx[k * fold_samples : (k + 1) * fold_samples].tolist()
        )
    kfold_split_idx.append(idx[(k + 1) * fold_samples :].tolist())
    self.cv_splits = kfold_split_idx
    return kfold_split_idx
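
For intuition, a quick worked example of the index layout this method produces (shuffling disabled so the result is deterministic):

from modelgenerator.data import KFoldMixin

mixin = KFoldMixin()  # __init__ just initializes cv_splits to None
folds = mixin.generate_kfold_split(num_samples=10, num_folds=3, shuffle=False)
# folds == [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
# Each fold gets num_samples // num_folds indices; the last fold absorbs the remainder.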

get_split_by_fold_id(train_dataset, val_dataset, test_dataset, fold_id, val_idx_offset=1)

Split the dataset into training, validation, and test sets based on the test fold id.

Source code in modelgenerator/data/base.py
def get_split_by_fold_id(
    self, train_dataset, val_dataset, test_dataset, fold_id, val_idx_offset=1
):
    """Split the dataset into training, validation, and test sets based on the fold id for test"""
    if self.cv_num_folds <= 1:
        return train_dataset, val_dataset, test_dataset
    if len(val_dataset) or len(test_dataset):
        rank_zero_warn(
            "Redundant val or test splits are not expected and will be ignored during training. "
            "Disable this warning by setting {test,val}_split_size=0 and {test,val}_split_name=None"
        )
    if self.cv_fold_id_col is not None:
        splits = self.read_kfold_split(train_dataset)
    else:
        splits = self.generate_kfold_split(len(train_dataset), self.cv_num_folds)
    test_idx = splits[fold_id]
    val_idx = (
        splits[(fold_id + val_idx_offset) % self.cv_num_folds]
        if self.cv_enable_val_fold
        else []
    )
    train_idx = list(set(range(len(train_dataset))) - set(test_idx) - set(val_idx))
    return (
        train_dataset.select(train_idx),
        train_dataset.select(val_idx),
        train_dataset.select(test_idx),
    )

read_kfold_split(dataset)

Read the fold ids from a pre-split dataset and return a list of fold indices.

Source code in modelgenerator/data/base.py
def read_kfold_split(self, dataset: datasets.Dataset):
    """Read the fold ids from a pre-split dataset and return list of fold_idx"""
    fold_ids = sorted(dataset.unique(self.cv_fold_id_col))
    if list(range(self.cv_num_folds)) != fold_ids:
        raise ValueError(f"Fold ids {fold_ids} do not match the expected range")
    kfold_split_idx = []
    for fold_id in fold_ids:
        kfold_split_idx.append(
            np.where(np.array(dataset[self.cv_fold_id_col], dtype=int) == fold_id)[
                0
            ]
        )
    self.cv_splits = kfold_split_idx
    return kfold_split_idx
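
In typical use these methods are not called directly; they are driven by the cv_* arguments on DataInterface. A sketch of enabling 5-fold cross-validation, reusing the hypothetical MyDataModule from earlier on this page:

dm = MyDataModule(
    path="data/my_task",
    cv_num_folds=5,           # split the training data into 5 folds
    cv_test_fold_id=0,        # hold out fold 0 for evaluation
    cv_enable_val_fold=True,  # fold (0 + cv_val_offset) % 5 = fold 1 becomes validation
    test_split_name=None,     # avoid the redundant-split warning mentioned above
    test_split_size=0,
    valid_split_size=0,
)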

Implementing a DataModule

To transform datasets for task-specific behaviors (e.g. masking for masked language modeling), use torch.utils.data.Dataset objects to implement the transformation. Below is an example.

modelgenerator.data.MLMDataModule

Bases: SequenceClassificationDataModule

Data module for continuing pretraining on a masked language modeling task. Inherits from SequenceClassificationDataModule.

Note

Each sample includes a single sequence under key 'sequences' and a single target sequence under key 'target_sequences'

Parameters:

masking_rate (float, default 0.15): The masking rate.
Source code in modelgenerator/data/data.py
class MLMDataModule(SequenceClassificationDataModule):
    """Data module for continuing pretraining on a masked language modeling task. Inherits from SequenceClassificationDataModule.

    Note:
        Each sample includes a single sequence under key 'sequences' and a single target sequence under key 'target_sequences'

    Args:
        masking_rate (float, optional): The masking rate. Defaults to 0.15.
    """

    def __init__(
        self,
        *args,
        masking_rate: float = 0.15,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.masking_rate = masking_rate

    def setup(self, stage: Optional[str] = None):
        super().setup(stage)
        self.train_dataset = MLMDataset(
            dataset=self.train_dataset,
            masking_rate=self.masking_rate,
        )
        self.val_dataset = MLMDataset(
            dataset=self.val_dataset,
            masking_rate=self.masking_rate,
        )
        self.test_dataset = MLMDataset(
            dataset=self.test_dataset,
            masking_rate=self.masking_rate,
        )

setup(stage=None)

Source code in modelgenerator/data/data.py
def setup(self, stage: Optional[str] = None):
    super().setup(stage)
    self.train_dataset = MLMDataset(
        dataset=self.train_dataset,
        masking_rate=self.masking_rate,
    )
    self.val_dataset = MLMDataset(
        dataset=self.val_dataset,
        masking_rate=self.masking_rate,
    )
    self.test_dataset = MLMDataset(
        dataset=self.test_dataset,
        masking_rate=self.masking_rate,
    )

modelgenerator.data.MLMDataset

Bases: Dataset

Masked Language Modeling Dataset

Note

Each sample includes a single sequence under key 'sequences' and a single target sequence under key 'target_sequences'

Parameters:

dataset (Dataset, required): The dataset to mask.
masking_rate (float, required): The masking rate.
Source code in modelgenerator/data/data.py
class MLMDataset(Dataset):
    """Masked Language Modeling Dataset

    Note:
        Each sample includes a single sequence under key 'sequences' and a single target sequence under key 'target_sequences'

    Args:
        dataset (Dataset): The dataset to mask
        masking_rate (float): The masking rate
    """

    def __init__(self, dataset, masking_rate):
        self.dataset = dataset
        self.masking_rate = masking_rate

    def get_masked_sample(self, seq_target, masking_rate):
        """
        Mask a sequence with a given masking rate
        """
        num_mask_tokens = max(1, int(len(seq_target) * masking_rate))
        perm = torch.randperm(len(seq_target))
        input_mask_indices = perm[:num_mask_tokens]
        # Mask the input sequence
        seq_input = replace_characters_at_indices(
            s=seq_target, indices=input_mask_indices, replacement_char="[MASK]"
        )
        return seq_input

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_dict = self.dataset[idx]
        seq_target = data_dict["sequences"]
        seq_input = self.get_masked_sample(seq_target, self.masking_rate)
        data_dict.update({"sequences": seq_input, "target_sequences": seq_target})
        # prepend __empty__ to help pytorch lightning infer batch size
        return {"__empty__": 0, **data_dict}

get_masked_sample(seq_target, masking_rate)

Mask a sequence with a given masking rate

Source code in modelgenerator/data/data.py
def get_masked_sample(self, seq_target, masking_rate):
    """
    Mask a sequence with a given masking rate
    """
    num_mask_tokens = max(1, int(len(seq_target) * masking_rate))
    perm = torch.randperm(len(seq_target))
    input_mask_indices = perm[:num_mask_tokens]
    # Mask the input sequence
    seq_input = replace_characters_at_indices(
        s=seq_target, indices=input_mask_indices, replacement_char="[MASK]"
    )
    return seq_input
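
A minimal usage sketch of MLMDataset with toy data (not from the library), showing the masking behavior:

from modelgenerator.data import MLMDataset

toy = [{"sequences": "ACGTACGTACGT"}]  # any indexable dataset of dicts with a 'sequences' key
masked = MLMDataset(dataset=toy, masking_rate=0.25)
sample = masked[0]
# sample["sequences"] now contains "[MASK]" at roughly 25% of randomly chosen positions,
# while sample["target_sequences"] keeps the original, unmasked sequence.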