Files
openpipeline/src/annotate/onclass/script.py
CI cd0af18851 Build branch fix-integration-tests with version dev (2dbe3b72)
Build pipeline: vsh-ci-dev-k8tz4

Source commit: 2dbe3b7231

Source message: Fix pointers to test resources
2024-10-17 17:56:12 +00:00

196 lines
7.6 KiB
Python

import sys
import logging
import mudata as mu
import anndata as ad
import re
import numpy as np
from OnClass.OnClassModel import OnClassModel
import obonet
from typing import Dict, Tuple
from tqdm import tqdm
## VIASH START
par = {
"input": "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_filtered_feature_bc_matrix.h5mu",
"output": "output.h5mu",
"modality": "rna",
"reference": "resources_test/annotation_test_data/TS_Blood_filtered.h5mu",
"model": None,
"reference_obs_targets": "cell_ontology_class",
"input_layer": None,
"reference_layer": None,
"max_iter": 100,
"output_obs_predictions": None,
"output_obs_probability": None,
"cl_nlp_emb_file": "resources_test/annotation_test_data/ontology/cl.ontology.nlp.emb",
"cl_ontology_file": "resources_test/annotation_test_data/ontology/cl.ontology",
"cl_obo_file": "resources_test/annotation_test_data/ontology/cl.obo",
"output_compression": "gzip"
}
meta = {"resources_dir": "src/annotate/onclass"}
## VIASH END
sys.path.append(meta["resources_dir"])
# START TEMPORARY WORKAROUND setup_logger
def setup_logger():
logger = logging.getLogger()
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler(sys.stdout)
logFormatter = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s")
console_handler.setFormatter(logFormatter)
logger.addHandler(console_handler)
return logger
# END TEMPORARY WORKAROUND setup_logger
logger = setup_logger()
def map_celltype_to_ontology_id(cl_obo_file: str) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
Map cell type names to ontology IDs and vice versa.
Parameters
----------
cl_obo_file : str
Path to the cell ontology file.
Returns
-------
Tuple[Dict[str, str], Dict[str, str]]
A tuple of two dictionaries. The first dictionary maps cell ontology IDs to cell type names.
The second dictionary maps cell type names to cell ontology IDs.
"""
graph = obonet.read_obo(cl_obo_file)
cl_id_to_name = {id_: data.get("name") for id_, data in graph.nodes(data=True)}
cl_id_to_name = {k: v for k, v in cl_id_to_name.items() if v is not None}
name_to_cl_id = {v: k for k, v in cl_id_to_name.items()}
return cl_id_to_name, name_to_cl_id
def predict_input_data(model: OnClassModel,
input_matrix: np.array,
input_modality: ad.AnnData,
id_to_name: dict,
obs_prediction: str,
obs_probability: str) -> ad.AnnData:
"""
Predict cell types for input data and save results to Anndata obj.
Parameters
----------
model : OnClassModel
The OnClass model.
input_matrix : np.array
The input data matrix.
input_modality : ad.AnnData
The input data Anndata object.
id_to_name : dict
Dictionary mapping cell ontology IDs to cell type names.
obs_prediction : str
The obs key for the predicted cell type.
obs_probability : str
The obs key for the predicted cell type probability.
Returns
-------
ad.AnnData
The input data Anndata object with the predicted cell types saved in obs.
"""
corr_test_feature = model.ProcessTestFeature(
test_feature=input_matrix,
test_genes=input_modality.var_names,
log_transform=False,
)
onclass_pred = model.Predict(corr_test_feature, use_normalize=False, refine=True, unseen_ratio=-1.0)
pred_label = [model.i2co[ind] for ind in onclass_pred[2]]
pred_cell_type_label = [id_to_name[id] for id in pred_label]
input_modality.obs[obs_prediction] = pred_cell_type_label
input_modality.obs[obs_probability] = np.max(onclass_pred[1], axis=1) / onclass_pred[1].sum(1)
return input_modality
def set_var_index(adata, var_name):
adata.var.index = [re.sub("\\.[0-9]+$", "", s) for s in adata.var[var_name]]
return adata
def main():
if (not par["model"] and not par["reference"]) or (par["model"] and par["reference"]):
raise ValueError("Make sure to provide either 'model' or 'reference', but not both.")
logger.info("Reading input data")
input_mudata = mu.read_h5mu(par["input"])
input_modality = input_mudata.mod[par["modality"]].copy()
# Set var names to the desired gene name format (gene synbol, ensembl id, etc.)
input_modality = set_var_index(input_modality, par["var_query_gene_names"]) if par["var_query_gene_names"] else input_modality
input_matrix = input_modality.layers[par["input_layer"]].toarray() if par["input_layer"] else input_modality.X.toarray()
id_to_name, name_to_id = map_celltype_to_ontology_id(par["cl_obo_file"])
if par["model"]:
logger.info("Predicting cell types using pre-trained model")
model = OnClassModel(cell_type_nlp_emb_file=par["cl_nlp_emb_file"],
cell_type_network_file=par["cl_ontology_file"])
model.BuildModel(use_pretrain=par["model"], ngene=None)
elif par["reference"]:
logger.info("Reading reference data")
model = OnClassModel(cell_type_nlp_emb_file=par["cl_nlp_emb_file"],
cell_type_network_file=par["cl_ontology_file"])
reference_mudata = mu.read_h5mu(par["reference"])
reference_modality = reference_mudata.mod[par["modality"]].copy()
reference_modality.var["gene_symbol"] = list(reference_modality.var.index)
reference_modality.var.index = [re.sub("\\.[0-9]+$", "", s) for s in reference_modality.var["ensemblid"]]
logger.info("Detecting common vars based on ensembl ids")
common_ens_ids = list(set(reference_modality.var.index).intersection(set(input_modality.var.index)))
logger.info(" reference n_vars: %i", reference_modality.n_vars)
logger.info(" input n_vars: %i", input_modality.n_vars)
logger.info(" intersect n_vars: %i", len(common_ens_ids))
assert len(common_ens_ids) >= 100, "The intersection of genes is too small."
reference_matrix = reference_modality.layers[par["reference_layer"]].toarray() if par["reference_layer"] else reference_modality.X.toarray()
logger.info("Training a model from reference...")
labels = reference_modality.obs[par["reference_obs_target"]].tolist()
labels_cl = [name_to_id[label] for label in labels]
_ = model.EmbedCellTypes(labels_cl)
(
corr_train_feature,
_,
corr_train_genes,
_,
) = model.ProcessTrainFeature(
train_feature=reference_matrix,
train_label=labels_cl,
train_genes=reference_modality.var_names,
test_feature=input_matrix,
test_genes=input_modality.var_names,
log_transform=False,
)
model.BuildModel(ngene=len(corr_train_genes))
model.Train(corr_train_feature,
labels_cl,
max_iter=par["max_iter"])
logger.info(f"Predicting cell types")
input_modality = predict_input_data(model,
input_matrix,
input_modality,
id_to_name,
par["output_obs_predictions"],
par["output_obs_probability"])
logger.info("Writing output data")
input_mudata.mod[par["modality"]] = input_modality
input_mudata.write_h5mu(par["output"], compression=par["output_compression"])
if __name__ == "__main__":
main()