Build pipeline: vsh-ci-dev-k8tz4
Source commit: 2dbe3b7231
Source message: Fix pointers to test resources
142 lines
5.9 KiB
Python
142 lines
5.9 KiB
Python
import sys
|
|
import os
|
|
import pytest
|
|
import re
|
|
import mudata as mu
|
|
import anndata as ad
|
|
from openpipelinetestutils.asserters import assert_annotation_objects_equal
|
|
import scvi
|
|
import os
|
|
## VIASH START
|
|
meta = {
|
|
"resources_dir": "resources_test"
|
|
}
|
|
## VIASH END
|
|
|
|
input_file = f"{meta['resources_dir']}/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu"
|
|
reference_file = f"{meta['resources_dir']}/annotation_test_data/TS_Blood_filtered.h5mu"
|
|
|
|
@pytest.fixture
|
|
def create_scvi_model(random_path, tmp_path):
|
|
def wrapper(input_file, reference_file):
|
|
input_data = mu.read_h5mu(input_file)
|
|
input_modality = input_data.mod["rna"]
|
|
reference_data = mu.read_h5mu(reference_file)
|
|
reference_modality = reference_data.mod["rna"]
|
|
|
|
reference_data.var["gene_symbol"] = list(reference_data.var.index)
|
|
reference_data.var.index = [re.sub("\\.[0-9]+$", "", s) for s in reference_data.var["ensemblid"]]
|
|
reference_modality.var["gene_symbol"] = list(reference_modality.var.index)
|
|
reference_modality.var.index = [re.sub("\\.[0-9]+$", "", s) for s in reference_modality.var["ensemblid"]]
|
|
|
|
common_ens_ids = list(set(reference_modality.var.index).intersection(set(input_modality.var.index)))
|
|
|
|
reference = reference_modality[:, common_ens_ids].copy()
|
|
query = input_modality[:, common_ens_ids].copy()
|
|
|
|
scvi.model.SCVI.setup_anndata(reference,
|
|
labels_key="cell_ontology_class"
|
|
)
|
|
|
|
scvi_model = scvi.model.SCVI(
|
|
reference,
|
|
use_layer_norm="both",
|
|
use_batch_norm="none",
|
|
encode_covariates=True,
|
|
dropout_rate=0.2,
|
|
n_layers=1,
|
|
)
|
|
scvi_model.train(max_epochs=10)
|
|
|
|
input_data.mod["rna"] = query
|
|
reference_data.mod["rna"] = reference
|
|
|
|
input_data_file = random_path(extension="h5mu")
|
|
reference_file = random_path(extension="h5mu")
|
|
scvi_model_file = tmp_path
|
|
|
|
input_data.write_h5mu(input_data_file)
|
|
reference_data.write_h5mu(reference_file)
|
|
scvi_model.save(scvi_model_file, overwrite=True)
|
|
|
|
return scvi_model_file, input_data_file, reference_file
|
|
return wrapper
|
|
|
|
def test_simple_execution(run_component, random_h5mu_path, create_scvi_model):
|
|
scvi_model_file, input_file_scvi, reference_file_scvi = create_scvi_model(input_file, reference_file)
|
|
output_file = random_h5mu_path()
|
|
|
|
run_component([
|
|
"--input", input_file_scvi,
|
|
"--reference", reference_file_scvi,
|
|
"--scvi_reference_model", scvi_model_file,
|
|
"--reference_obs_label", "cell_ontology_class",
|
|
"--reference_max_epochs", "10",
|
|
"--query_max_epochs", "10",
|
|
"--output", output_file
|
|
])
|
|
|
|
assert os.path.exists(output_file), "Output file does not exist"
|
|
|
|
input_mudata = mu.read_h5mu(input_file_scvi)
|
|
output_mudata = mu.read_h5mu(output_file)
|
|
|
|
assert input_mudata.mod["rna"].n_obs == output_mudata.mod["rna"].n_obs, f"Number of observations changed"
|
|
assert input_mudata.mod["rna"].n_vars == output_mudata.mod["rna"].n_vars, f"Number of variables changed"
|
|
assert "scanvi_embedding" in output_mudata.mod["rna"].obsm.keys(), "Latent representation not added"
|
|
assert "scanvi_pred" in output_mudata.mod["rna"].obs.keys(), "Predictions not added"
|
|
assert "scanvi_probability" in output_mudata.mod["rna"].obs.keys(), "Probabilities not added"
|
|
|
|
assert_annotation_objects_equal(input_mudata.mod["prot"],
|
|
output_mudata.mod["prot"])
|
|
|
|
def test_multiple_arguments(run_component, random_h5mu_path, create_scvi_model, tmp_path):
|
|
scvi_model_file, input_file_scvi, reference_file_scvi = create_scvi_model(input_file, reference_file)
|
|
output_file = random_h5mu_path()
|
|
|
|
run_component([
|
|
"--input", input_file_scvi,
|
|
"--reference", reference_file_scvi,
|
|
"--scvi_reference_model", scvi_model_file,
|
|
"--reference_obs_label", "cell_ontology_class",
|
|
"--output", output_file,
|
|
"--reference_max_epochs", "10",
|
|
"--reference_reduce_lr_on_plateau", "True",
|
|
"--reference_lr_patience", "5",
|
|
"--reference_lr_factor", "0.5",
|
|
"--reference_train_size", "0.8",
|
|
"--reference_early_stopping", "True",
|
|
"--reference_early_stopping_patience", "5",
|
|
"--reference_early_stopping_min_delta", "0.01",
|
|
"--query_max_epochs", "10",
|
|
"--query_reduce_lr_on_plateau", "True",
|
|
"--query_lr_patience", "5",
|
|
"--query_lr_factor", "0.5",
|
|
"--query_train_size", "0.8",
|
|
"--query_early_stopping", "True",
|
|
"--query_early_stopping_patience", "5",
|
|
"--query_early_stopping_min_delta", "0.01",
|
|
"--output_obs_predictions", "scanvi_pred",
|
|
"--output_obs_probabilities", "scanvi_probabilitity",
|
|
"--output_compression", "gzip",
|
|
"--output_model", tmp_path
|
|
])
|
|
|
|
assert os.path.exists(output_file), "Output file does not exist"
|
|
assert os.path.exists(tmp_path / "model.pt"), "Model file does not exist"
|
|
|
|
input_mudata = mu.read_h5mu(input_file_scvi)
|
|
output_mudata = mu.read_h5mu(output_file)
|
|
|
|
assert input_mudata.mod["rna"].n_obs == output_mudata.mod["rna"].n_obs, f"Number of observations changed"
|
|
assert input_mudata.mod["rna"].n_vars == output_mudata.mod["rna"].n_vars, f"Number of variables changed"
|
|
assert "scanvi_embedding" in output_mudata.mod["rna"].obsm.keys(), "Latent representation not added"
|
|
assert "scanvi_pred" in output_mudata.mod["rna"].obs.keys(), "Predictions not added"
|
|
assert "scanvi_probability" in output_mudata.mod["rna"].obs.keys(), "Probabilities not added"
|
|
|
|
|
|
assert_annotation_objects_equal(input_mudata.mod["prot"],
|
|
output_mudata.mod["prot"])
|
|
|
|
if __name__ == '__main__':
|
|
sys.exit(pytest.main([__file__])) |