
Train an ML model on a collection#

In the previous tutorial, we loaded an entire collection into memory to perform a simple analysis.

Here, we'll iterate over the artifacts within the collection to train an ML model.

import lamindb as ln
import anndata as ad
import numpy as np
💡 lamindb instance: testuser1/test-scrna
ln.track()
💡 notebook imports: anndata==0.9.2 lamindb==0.65.0 numpy==1.26.3 torch==2.1.2
💡 saved: Transform(uid='Qr1kIHvK506r5zKv', name='Train an ML model on a collection', short_name='scrna5', version='1', type=notebook, updated_at=2024-01-03 01:34:36 UTC, created_by_id=1)
💡 saved: Run(uid='1ByLthAeN8ZzwfslGde8', run_at=2024-01-03 01:34:36 UTC, transform_id=5, created_by_id=1)

Preprocessing#

Let us get our collection:

collection_v2 = ln.Collection.filter(
    name="My versioned scRNA-seq collection", version="2"
).one()
collection_v2
Collection(uid='dXqqPuhgFCXaWOCKJfjb', name='My versioned scRNA-seq collection', version='2', hash='BOAf0T5UbN_iOe3fQDyq', visibility=1, updated_at=2024-01-03 01:34:14 UTC, transform_id=2, run_id=2, created_by_id=1)

PyTorch DataLoader#

If you need to train your model on a list of artifacts, you can use mapped() with the PyTorch DataLoader.

It loads only the requested batches into memory, which makes it possible to work with very large collections.

from torch.utils.data import DataLoader, WeightedRandomSampler

Let us create a MappedCollection object from Collection.

Under the hood, it performs a virtual inner join of the variables of the underlying AnnData objects.

ds_mapped = collection_v2.mapped(label_keys=["cell_type"])

The intersected variable names can be accessed:

len(ds_mapped.var_joint)
749
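
To make the idea of an inner join over variables concrete, here is a minimal sketch in plain Python. The gene lists and the helper `inner_join_vars` are hypothetical and only illustrate the concept; this is not how lamindb implements the join.

```python
def inner_join_vars(var_lists):
    """Return the variable names shared by all lists, in the order of the first list."""
    if not var_lists:
        return []
    shared = set(var_lists[0])
    for names in var_lists[1:]:
        shared &= set(names)  # keep only names present in every list so far
    return [name for name in var_lists[0] if name in shared]


# Hypothetical variable names of three AnnData objects
vars_a = ["CD3E", "CD4", "CD8A", "MS4A1"]
vars_b = ["CD4", "MS4A1", "NKG7", "CD3E"]
vars_c = ["MS4A1", "CD3E", "CD4", "LYZ"]

var_joint = inner_join_vars([vars_a, vars_b, vars_c])
print(var_joint)
```

Only the names present in all three lists survive, which is why the joint variable count can be smaller than that of any single artifact.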

This is compatible with the PyTorch DataLoader because it implements __getitem__ over a list of backed AnnData objects.

ds_mapped[5]
[array([0.       , 0.       , 0.       , 1.513164 , 0.       , 0.       ,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 3.4371805,
        0.       , 0.       , 0.       , 1.513164 , 0.       , 0.       ,
        1.513164 , 1.0190415, 0.       , 0.       , 1.0190415, 1.0190415,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 0.       , 2.0896583, 0.       , 1.513164 , 1.8423915,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 1.0190415, 0.       , 0.       , 0.       , 0.       ,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 1.8423915,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 1.0190415, 0.       , 0.       ,
        1.0190415, 0.       , 1.8423915, 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 1.0190415, 0.       , 0.       ,
        0.       , 1.513164 , 0.       , 0.       , 2.2877445, 2.4530056,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 1.0190415, 0.       , 0.       , 0.       , 1.0190415,
        1.0190415, 1.0190415, 1.0190415, 1.8423915, 0.       , 0.       ,
        2.0896583, 1.0190415, 1.8423915, 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.513164 , 0.       , 1.513164 , 1.0190415,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 0.       ,
        1.0190415, 0.       , 1.0190415, 0.       , 0.       , 1.0190415,
        0.       , 1.0190415, 1.513164 , 0.       , 1.0190415, 0.       ,
        0.       , 1.0190415, 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        2.2877445, 0.       , 1.513164 , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        1.0190415, 0.       , 1.513164 , 1.0190415, 0.       , 0.       ,
        1.0190415, 1.8423915, 2.0896583, 1.0190415, 0.       , 1.513164 ,
        1.0190415, 1.0190415, 1.0190415, 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.0190415, 1.513164 ,
        0.       , 0.       , 1.8423915, 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 5.1414976, 0.       , 2.4530056, 1.0190415, 1.513164 ,
        0.       , 0.       , 1.0190415, 0.       , 2.4530056, 1.0190415,
        0.       , 0.       , 1.0190415, 0.       , 1.513164 , 0.       ,
        1.8423915, 0.       , 0.       , 1.513164 , 2.9288113, 0.       ,
        0.       , 0.       , 2.2877445, 1.0190415, 3.8508706, 1.513164 ,
        3.2498913, 2.4530056, 2.2877445, 0.       , 1.8423915, 0.       ,
        2.7189443, 2.9288113, 0.       , 1.0190415, 0.       , 1.0190415,
        0.       , 1.0190415, 1.0190415, 0.       , 2.2877445, 1.0190415,
        0.       , 1.8423915, 1.0190415, 1.0190415, 1.0190415, 0.       ,
        0.       , 0.       , 1.513164 , 1.0190415, 1.8423915, 0.       ,
        0.       , 1.0190415, 1.8423915, 1.8423915, 0.       , 0.       ,
        1.513164 , 0.       , 1.0190415, 0.       , 2.0896583, 1.0190415,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 1.0190415,
        1.0190415, 2.5947895, 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 1.513164 , 0.       , 1.0190415,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 2.0896583, 0.       , 0.       , 0.       ,
        0.       , 1.0190415, 0.       , 1.0190415, 0.       , 0.       ,
        0.       , 1.0190415, 0.       , 0.       , 0.       , 0.       ,
        0.       , 1.0190415, 0.       , 1.513164 , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.513164 , 0.       ,
        1.0190415, 1.513164 , 0.       , 1.0190415, 0.       , 1.513164 ,
        0.       , 1.0190415, 0.       , 1.0190415, 0.       , 0.       ,
        1.0190415, 0.       , 0.       , 0.       , 3.4371805, 0.       ,
        0.       , 0.       , 0.       , 2.0896583, 0.       , 0.       ,
        1.0190415, 0.       , 1.0190415, 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.513164 , 1.0190415,
        3.1787617, 0.       , 1.0190415, 1.0190415, 1.0190415, 0.       ,
        0.       , 1.0190415, 0.       , 1.0190415, 0.       , 0.       ,
        0.       , 0.       , 1.513164 , 1.513164 , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.513164 , 1.513164 , 1.0190415, 0.       ,
        0.       , 0.       , 2.4530056, 0.       , 1.513164 , 0.       ,
        0.       , 0.       , 0.       , 1.0190415, 1.0190415, 0.       ,
        1.0190415, 0.       , 1.513164 , 0.       , 0.       , 1.0190415,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 1.0190415,
        0.       , 1.0190415, 0.       , 1.513164 , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 1.513164 , 0.       , 0.       , 0.       , 1.0190415,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 2.0896583, 0.       , 2.4530056, 1.0190415,
        0.       , 0.       , 0.       , 1.513164 , 0.       , 1.513164 ,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 1.0190415,
        1.0190415, 0.       , 0.       , 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 2.2877445, 0.       , 2.5947895,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 1.0190415,
        2.9288113, 1.0190415, 2.4530056, 0.       , 0.       , 1.0190415,
        2.0896583, 0.       , 0.       , 0.       , 0.       , 0.       ,
        1.513164 , 0.       , 1.0190415, 0.       , 1.0190415, 0.       ,
        0.       , 1.0190415, 0.       , 1.0190415, 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.0190415, 1.513164 ,
        0.       , 0.       , 0.       , 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 1.0190415, 1.513164 , 1.0190415,
        0.       , 0.       , 0.       , 2.0896583, 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.8423915, 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 1.513164 , 0.       , 0.       ,
        0.       , 1.8423915, 0.       , 0.       , 1.8423915, 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.0190415, 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 1.513164 ,
        0.       , 0.       , 1.513164 , 0.       , 2.4530056, 0.       ,
        1.513164 , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.513164 , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 1.0190415, 1.0190415,
        1.513164 , 0.       , 0.       , 0.       , 1.0190415, 1.0190415,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 0.       ,
        1.0190415, 1.0190415, 0.       , 0.       , 0.       , 1.513164 ,
        0.       , 0.       , 0.       , 0.       , 0.       , 1.0190415,
        1.8423915, 2.0896583, 0.       , 0.       , 0.       , 1.0190415,
        0.       , 2.9288113, 0.       , 1.0190415, 0.       , 1.0190415,
        1.0190415, 0.       , 1.8423915, 0.       , 0.       , 0.       ,
        0.       , 0.       , 1.0190415, 0.       , 0.       , 2.0896583,
        0.       , 0.       , 0.       , 0.       , 0.       , 3.3162956,
        0.       , 2.4530056, 0.       , 1.0190415, 0.       , 1.0190415,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        1.0190415, 0.       , 0.       , 0.       , 0.       , 1.0190415,
        0.       , 0.       , 0.       , 1.0190415, 0.       , 0.       ,
        1.0190415, 1.513164 , 1.8423915, 0.       , 2.4530056, 1.0190415,
        1.0190415, 0.       , 1.0190415, 1.0190415, 0.       , 1.513164 ,
        0.       , 1.0190415, 0.       , 0.       , 0.       , 1.0190415,
        1.0190415, 0.       , 0.       , 1.0190415, 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        1.0190415, 0.       , 2.0896583, 0.       , 1.0190415, 1.0190415,
        0.       , 0.       , 0.       , 1.513164 , 0.       , 0.       ,
        0.       , 1.513164 , 1.0190415, 1.0190415, 0.       , 0.       ,
        0.       , 0.       , 1.0190415, 1.0190415, 1.0190415, 1.513164 ,
        0.       , 1.0190415, 0.       , 0.       , 0.       , 4.5521755,
        1.0190415, 0.       , 1.0190415, 0.       , 4.14264  ],
       dtype=float32),
 4]
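
Indexing across several backed objects requires translating a global observation index into a storage and a local row. How MappedCollection does this internally is not shown in this tutorial; the following is a sketch of one common approach using cumulative sizes and binary search. The class `ConcatIndex` and the sizes are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


class ConcatIndex:
    """Map a global observation index to (storage index, local index)."""

    def __init__(self, n_obs_per_storage):
        # Cumulative end offsets, e.g. [3, 5, 9] for sizes [3, 2, 4]
        self.offsets = list(accumulate(n_obs_per_storage))

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def locate(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First storage whose end offset is greater than idx
        storage_idx = bisect_right(self.offsets, idx)
        start = self.offsets[storage_idx - 1] if storage_idx > 0 else 0
        return storage_idx, idx - start


index = ConcatIndex([3, 2, 4])  # hypothetical observation counts per artifact
print(len(index), index.locate(0), index.locate(3), index.locate(8))
```

A `__getitem__` built on such a lookup only needs to read a single row from a single file, which is what keeps memory usage low.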

The labels are encoded into integers.

ds_mapped.encoders
[{'lymphocyte': 0,
  'T follicular helper cell': 1,
  'CD4-positive helper T cell': 2,
  'gamma-delta T cell': 3,
  'memory B cell': 4,
  'non-classical monocyte': 5,
  'alpha-beta T cell': 6,
  'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 7,
  'CD4-positive, alpha-beta T cell': 8,
  'plasmablast': 9,
  'plasmacytoid dendritic cell': 10,
  'plasma cell': 11,
  'effector memory CD4-positive, alpha-beta T cell': 12,
  'naive B cell': 13,
  'dendritic cell': 14,
  'animal cell': 15,
  'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 16,
  'megakaryocyte': 17,
  'mast cell': 18,
  'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 19,
  'regulatory T cell': 20,
  'mucosal invariant T cell': 21,
  'CD38-positive naive B cell': 22,
  'group 3 innate lymphoid cell': 23,
  'dendritic cell, human': 24,
  'CD16-positive, CD56-dim natural killer cell, human': 25,
  'progenitor cell': 26,
  'germinal center B cell': 27,
  'classical monocyte': 28,
  'conventional dendritic cell': 29,
  'alveolar macrophage': 30,
  'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 31,
  'cytotoxic T cell': 32,
  'B cell, CD19-positive': 33,
  'CD16-negative, CD56-bright natural killer cell, human': 34,
  'macrophage': 35,
  'CD14-positive, CD16-negative classical monocyte': 36,
  'CD8-positive, alpha-beta memory T cell': 37,
  'naive thymus-derived CD8-positive, alpha-beta T cell': 38,
  'naive thymus-derived CD4-positive, alpha-beta T cell': 39}]
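
In the mapping above, the label 4 returned by ds_mapped[5] corresponds to 'memory B cell'. As an illustration of how such a mapping can be built and inverted, here is a minimal sketch; `build_encoder` is a hypothetical helper, and the order in which lamindb assigns codes may differ.

```python
def build_encoder(labels):
    """Assign consecutive integer codes to labels in order of first appearance."""
    encoder = {}
    for label in labels:
        if label not in encoder:
            encoder[label] = len(encoder)
    return encoder


# Hypothetical per-cell annotations
cell_types = ["memory B cell", "lymphocyte", "memory B cell", "plasma cell"]

encoder = build_encoder(cell_types)
encoded = [encoder[label] for label in cell_types]

# Invert the mapping to translate model predictions back into label names
decoder = {code: label for label, code in encoder.items()}
print(encoded, decoder[encoded[-1]])
```

Keeping the inverse mapping around is useful after training, when integer predictions have to be reported as cell type names.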

Let us use a weighted sampler:

# the label key used for weights does not have to be among the label_keys passed to mapped()
sampler = WeightedRandomSampler(
    weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
)
dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
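
A weighted sampler is typically used to counter class imbalance. The sketch below assumes that the weights are inversely proportional to label frequency; this is an assumption about what get_label_weights returns, not something this tutorial confirms. It uses only the standard library, with `random.choices` standing in for sampling with replacement, and the labels are hypothetical.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Give each observation the weight 1 / (number of observations with its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical imbalanced labels: 8 T cells, 2 B cells
labels = ["T cell"] * 8 + ["B cell"] * 2
weights = inverse_frequency_weights(labels)

# With these weights, every class carries the same total weight
total_per_class = Counter()
for label, weight in zip(labels, weights):
    total_per_class[label] += weight

# Draw with replacement, proportionally to the weights
rng = random.Random(0)
drawn = Counter(rng.choices(labels, weights=weights, k=10_000))
print(dict(total_per_class), dict(drawn))
```

Under this weighting, rare cell types are drawn about as often as abundant ones, so individual cells of rare types are repeated within an epoch.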

We can now iterate through the data loader:

for batch in dl:
    pass

Close the connections in MappedCollection:

ds_mapped.close()
In practice, use a context manager:
with collection_v2.mapped(label_keys=["cell_type"]) as ds_mapped:
    sampler = WeightedRandomSampler(
        weights=ds_mapped.get_label_weights("cell_type"), num_samples=len(ds_mapped)
    )
    dl = DataLoader(ds_mapped, batch_size=128, sampler=sampler)
    for batch in dl:
        pass
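
The benefit of the with statement is that cleanup also runs when the loop body raises. The following sketch shows the general Python protocol with a hypothetical stand-in class, `MappedHandles`; it does not reproduce the MappedCollection implementation.

```python
class MappedHandles:
    """Hypothetical stand-in for an object that holds open file connections."""

    def __init__(self, paths):
        self.paths = list(paths)
        self.closed = False

    def close(self):
        self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # do not suppress exceptions


# Normal exit: close() is called when the block ends
with MappedHandles(["a.h5ad", "b.h5ad"]) as handles:
    pass

# Exit through an exception: close() is still called
try:
    with MappedHandles(["a.h5ad"]) as failing:
        raise RuntimeError("training step failed")
except RuntimeError:
    pass

print(handles.closed, failing.closed)
```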