Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer | |
| from transformers.pipelines import Pipeline | |
| try: | |
| import requests | |
| except Exception: | |
| requests = None | |
| try: | |
| import matplotlib.pyplot as plt | |
| except Exception: | |
| plt = None | |
# ---------------------------------------------------------------------
# Assembly <-> species mapping
# ---------------------------------------------------------------------
# Maps a genome assembly name to the species name used by this module.
ASSEMBLY_TO_SPECIES = {
    "hg38": "human",
    "mm10": "mouse",
    "dm6": "drosophila_melanogaster",
    "TAIR10": "arabidopsis_thaliana",
    "Zm-B73-REFERENCE-NAM-5.0": "zea_mays",
    "IRGSP-1.0": "oryza_sativa",
    "Glycine_max_v2.1": "glycine_max",
    "IWGSC": "triticum_aestivum",
    "Gossypium_hirsutum_v2.1": "gossypium_hirsutum",
    "AmpOce1": "amphiprion_ocellaris",
    "Bison_UMD1": "bison_bison_bison",
    "ChiLan1": "chinchilla_lanigera",
    "Felis_catus_9": "felis_catus",
    "GRCz11": "danio_rerio",
    "KH": "ciona_intestinalis",
    "Mnem_1": "macaca_nemestrina",
    "ROS_Cfam_1": "canis_lupus_familiaris",
    "SCA1": "serinus_canaria",
    "TETRAODON8": "tetraodon_nigroviridis",
    "WBcel235": "caenorhabditis_elegans",
    "bGalGal1": "gallus_gallus",
    "fSalTru1": "salmo_trutta",
    "gorGor4": "gorilla_gorilla",
    "mRatBN7": "rattus_norvegicus",
}
# Reverse lookup (species -> assembly). Every species appears once above, so
# the inversion loses no entries. Used to validate a requested species and to
# pick the assembly whose sequence is fetched.
SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
# ---------------------------------------------------------------------
# Species that support coordinate-based sequence fetching
# ---------------------------------------------------------------------
# Set of species whose DNA sequence can be fetched from genomic coordinates
# via an API. Species not in this set can still be used but require direct
# DNA sequence input.
# NOTE(review): nothing in this module reads this set; presumably it is
# consumed by the calling app -- confirm before relying on it for validation.
SPECIES_WITH_COORDINATE_SUPPORT = {
    "human",  # hg38 - UCSC API
    "mouse",  # mm10 - UCSC API
    "drosophila_melanogaster",  # dm6 - UCSC API
    "arabidopsis_thaliana",  # TAIR10 - UCSC hub API
    "gorilla_gorilla",  # gorGor4 - UCSC API
    # Add more species as API URLs are configured
}
# ---------------------------------------------------------------------
# Assembly -> API URL template mapping
# ---------------------------------------------------------------------
# Default API URL template (UCSC format). It is used for every assembly that
# has no entry in ASSEMBLY_TO_API_URL_TEMPLATE and is formatted with
# {assembly}, {chrom}, {start}, and {end}.
DEFAULT_API_URL_TEMPLATE = "https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}"  # noqa: E501
# Per-assembly overrides: for an assembly whose URL differs from the default
# format, add the assembly name to this mapping.
# The template should use {chrom}, {start}, and {end} as placeholders.
ASSEMBLY_TO_API_URL_TEMPLATE = {
    # Arabidopsis thaliana (TAIR10) - uses hub URL format
    "TAIR10": "https://api.genome.ucsc.edu/getData/sequence?hubUrl=http://genome.ucsc.edu/goldenPath/help/examples/hubExamples/hubAssembly/plantAraTha1/hub.txt;genome=araTha1;chrom={chrom};start={start};end={end}",  # noqa: E501
}
# BED element to color mapping (shared between pipeline and app).
# Keys are BED element names; values are hex colors. The plotting helper uses
# the color of a track whose title matches a key and a default blue otherwise.
BED_ELEMENT_COLORS = {
    "protein coding gene": "#E74C3C",  # Red
    "lncRNA": "#2ECC71",  # Green
    "exon": "#9B59B6",  # Purple
    "intron": "#F39C12",  # Orange
    "splice_donor": "#1ABC9C",  # Teal
    "splice_acceptor": "#E67E22",  # Dark orange
    "CTCF-bound": "#3498DB",  # Light blue
    "polyA_signal": "#95A5A6",  # Gray
    "enhancer Tissue specific": "#D35400",  # Burnt orange
    "enhancer Tissue invariant": "#16A085",  # Dark teal
    "promoter Tissue specific": "#C0392B",  # Dark red 2
    "promoter Tissue invariant": "#27AE60",  # Dark green
    "5UTR+": "#8E44AD",  # Dark purple
    "5UTR-": "#D68910",  # Dark orange 2
    "3UTR+": "#138D75",  # Dark teal 2
    "3UTR-": "#2874A6",  # Dark blue
    "skipped exon": "#7D3C98",  # Purple 2
    "always on exon": "#A93226",  # Red 2
    "start codon": "#196F3D",  # Green 2
    "stop codon": "#B9770E",  # Brown
    "ORF": "#1F618D",  # Blue 2
}
| def _filter_bed_elements_by_species( | |
| bed_element_names: list[str], species: str | |
| ) -> list[str]: | |
| """ | |
| Filter BED element names based on species-specific training data availability. | |
| Rules: | |
| - Human: all tracks | |
| - Mouse: only polyA_signal | |
| - Other species: everything except promoter, enhancer, ctcf, lncrna | |
| Parameters | |
| ---------- | |
| bed_element_names : list[str] | |
| Full list of BED element names from the model config | |
| species : str | |
| Species name (e.g., "human", "mouse", "drosophila_melanogaster") | |
| Returns | |
| ------- | |
| list[str] | |
| Filtered list of BED element names available for this species | |
| """ | |
| if not bed_element_names: | |
| return [] | |
| # Elements to exclude for "other species" (everything except human and mouse) | |
| excluded_for_other_species = { | |
| "promoter Tissue specific", | |
| "promoter Tissue invariant", | |
| "enhancer Tissue specific", | |
| "enhancer Tissue invariant", | |
| "CTCF-bound", | |
| "lncRNA", | |
| } | |
| # Normalize element names (handle both with/without underscores/spaces) | |
| normalized_excluded = set() | |
| for elem in excluded_for_other_species: | |
| normalized_excluded.add(elem) | |
| normalized_excluded.add(elem.replace(" ", "_")) | |
| if species == "human": | |
| # Human: all tracks | |
| return list(bed_element_names) | |
| else: | |
| # Other species: everything except promoter, enhancer, ctcf, lncrna | |
| # Normalize element names for comparison (handle spaces, underscores, case) | |
| normalized_bed_names = { | |
| elem.lower().replace("_", " "): elem | |
| for elem in bed_element_names | |
| } | |
| normalized_excluded_lower = { | |
| elem.lower().replace("_", " ") | |
| for elem in excluded_for_other_species | |
| } | |
| # Also check for keywords in element names | |
| excluded_keywords = ["promoter", "enhancer", "ctcf", "lnc"] | |
| filtered_normalized = [ | |
| norm_name | |
| for norm_name, orig_elem in normalized_bed_names.items() | |
| if norm_name not in normalized_excluded_lower | |
| and not any(keyword in norm_name for keyword in excluded_keywords) | |
| ] | |
| # Return original element names (preserving original format) | |
| return [ | |
| normalized_bed_names[norm_name] for norm_name in filtered_normalized | |
| ] | |
| def _sanitize_dna(seq: str) -> str: | |
| seq = seq.upper() | |
| return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq) | |
def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int) -> str:
    """
    Fetch DNA sequence from API based on assembly, chromosome, and coordinates.

    Uses ASSEMBLY_TO_API_URL_TEMPLATE to determine the API URL format for each
    assembly. Falls back to DEFAULT_API_URL_TEMPLATE if assembly is not in the
    mapping.

    Parameters
    ----------
    assembly : str
        Assembly name; selects the URL template and fills ``{assembly}``.
    chrom : str
        Chromosome name substituted into the URL.
    start, end : int
        Coordinates substituted into the URL as given.

    Returns
    -------
    str
        The upper-cased value of the ``"dna"`` field of the JSON response.

    Raises
    ------
    ImportError
        If ``requests`` could not be imported.
    requests.HTTPError
        If the API answers with an HTTP error status.
    requests.Timeout
        If the API does not answer within the timeout.
    KeyError
        If the JSON response has no ``"dna"`` field.
    """
    if requests is None:
        raise ImportError(
            "requests is required for genome download. "
            "Install with: pip install requests"
        )
    # Get API URL template for this assembly, or use default
    url_template = ASSEMBLY_TO_API_URL_TEMPLATE.get(assembly, DEFAULT_API_URL_TEMPLATE)
    # Format the URL with the provided parameters
    url = url_template.format(assembly=assembly, chrom=chrom, start=start, end=end)
    # Without a timeout requests waits forever on a stalled connection.
    response = requests.get(url, timeout=60)
    # Surface HTTP errors directly instead of failing later on a missing
    # "dna" key or an undecodable body.
    response.raise_for_status()
    return response.json()["dna"].upper()
| def _pick_device(device: str | int | torch.device) -> torch.device: | |
| # Handle torch.device objects | |
| if isinstance(device, torch.device): | |
| return device | |
| # Handle integer device IDs (transformers pipeline convention) | |
| if isinstance(device, int): | |
| if device == -1: | |
| return torch.device("cpu") | |
| elif device >= 0: | |
| if torch.cuda.is_available(): | |
| return torch.device(f"cuda:{device}") | |
| else: | |
| return torch.device("cpu") | |
| else: | |
| raise ValueError(f"Invalid device integer: {device}") | |
| # Handle string device names | |
| if isinstance(device, str): | |
| d = device.lower() | |
| if d == "auto": | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| return torch.device("cpu") | |
| if d in ("cuda", "cpu", "mps"): | |
| return torch.device(d) | |
| raise ValueError( | |
| "device must be one of: 'auto', 'cpu', 'cuda', 'mps', or an integer" | |
| ) | |
| raise ValueError( | |
| f"device must be a string, integer, or torch.device, got {type(device)}" | |
| ) | |
| def _softmax_last(x: np.ndarray) -> np.ndarray: | |
| x = x - x.max(axis=-1, keepdims=True) | |
| ex = np.exp(x) | |
| return ex / ex.sum(axis=-1, keepdims=True) | |
def _plot_tracks_fillbetween(
    tracks: dict[str, np.ndarray],
    chrom: str | None,
    start: int,
    end: int,
    assembly: str | None,
    height: float = 1.0,
    figsize_x: float = 20.0,
):
    """
    Draw one filled line plot per track, stacked vertically on a shared x-axis.

    Parameters
    ----------
    tracks : dict[str, np.ndarray]
        Mapping of subplot title to the values to draw. The x positions are
        derived from the length of the first track.
    chrom : str | None
        Chromosome name for the x-axis label; omitted from the label if None.
    start, end : int
        Range covered by the x-axis (``end`` is excluded).
    assembly : str | None
        Appended to the x-axis label in parentheses when given.
    height : float
        Figure height per track.
    figsize_x : float
        Figure width.

    Returns
    -------
    tuple
        ``(fig, axes)`` as created by ``plt.subplots``; ``axes`` is a
        one-element list when there is a single track.

    Raises
    ------
    ImportError
        If matplotlib could not be imported.
    ValueError
        If ``tracks`` is empty.
    """
    if plt is None:
        raise ImportError(
            "matplotlib is required for plotting. Install with: pip install matplotlib"
        )
    num_tracks = len(tracks)
    if num_tracks == 0:
        raise ValueError("No tracks to plot.")

    fig, axes = plt.subplots(
        num_tracks, 1, figsize=(figsize_x, height * num_tracks), sharex=True
    )
    if num_tracks == 1:
        # plt.subplots returns a bare Axes for a single row.
        axes = [axes]

    first_values = next(iter(tracks.values()))
    positions = np.linspace(start, end, num=len(first_values), endpoint=False)

    # Tracks that are not BED elements (BigWig tracks) share one blue tone.
    default_color = "#4A90E2"
    for axis, (title, values) in zip(axes, tracks.items()):
        track_color = BED_ELEMENT_COLORS.get(title, default_color)
        axis.fill_between(
            positions, values, color=track_color, alpha=0.3, linewidth=0
        )
        axis.plot(positions, values, color=track_color, linewidth=0.8)
        axis.set_title(title, fontsize=10, loc="left")
        axis.grid(alpha=0.2)
        axis.set_yticks([])
        # minimal "despine"
        for side in ("top", "right"):
            axis.spines[side].set_visible(False)

    if chrom is None:
        label = f"{start}-{end}"
    else:
        label = f"{chrom}:{start}-{end}"
    if assembly is not None:
        label += f" ({assembly})"
    axes[-1].set_xlabel(label)
    plt.tight_layout()
    return fig, axes
@dataclass
class NTv3TracksOutput:
    """
    Result of one ``NTv3TracksPipeline`` call: the three logit arrays plus the
    metadata describing the window they were computed for.

    The class must be a dataclass: the pipeline constructs it with keyword
    arguments, which a plain class with only annotations would reject.
    """

    # Logit arrays with the batch dimension already removed.
    bigwig_tracks_logits: np.ndarray  # (L_pred, T)
    bed_tracks_logits: np.ndarray  # (L_pred, E, C)
    mlm_logits: np.ndarray
    # Input window; chrom is None when a raw sequence was supplied.
    chrom: str | None = None
    start: int | None = None
    end: int | None = None
    species: str | None = None
    assembly: str | None = None
    bigwig_track_names: list[str] | None = (
        None  # from cfg.bigwigs_per_species[species]
    )
    # BED element names matching the E axis of bed_tracks_logits.
    bed_element_names: list[str] | None = None
    window_len: int | None = None
    # Prediction interval used for the plot x-axis.
    pred_start: int | None = None
    pred_end: int | None = None
class NTv3TracksPipeline(Pipeline):
    """
    Pipeline that runs an NTv3 checkpoint on one DNA window and returns the
    BigWig, BED and MLM logits as an ``NTv3TracksOutput``.

    The input is a dict with an optional ``"species"`` entry and either a raw
    ``"seq"`` string or ``"chrom"``/``"start"``/``"end"`` coordinates, in
    which case the sequence is downloaded. Calling the pipeline with
    ``plot=True`` additionally draws the selected tracks.
    """

    def __init__(
        self,
        model: str | torch.nn.Module,
        tokenizer: str | Any | None = None,
        trust_remote_code: bool = True,
        token: str | None = None,
        default_species: str = "human",
        genome_cache_dir: str | Path = "~/.cache/ntv3/genomes",
        device: str = "auto",
        mps_force_cpu: bool = True,
        mps_force_cpu_length: int = 16384,
        verbose: bool = True,
        # Your notebook uses these constants for "middle 37.5%" prediction span
        pred_center_fraction: float = 0.375,
        pred_center_offset_fraction: float = 0.3125,
        **kwargs: Any,
    ):
        """
        Load (or adopt) the model and tokenizer and place the model on a device.

        Parameters
        ----------
        model : str | torch.nn.Module
            Checkpoint id to load with ``AutoModel``/``AutoConfig``, or an
            already constructed module (its ``config`` attribute is used).
        tokenizer : str | Any | None
            Tokenizer id, tokenizer object, or None to load the tokenizer of
            ``model`` (only possible when ``model`` is a checkpoint id).
        trust_remote_code : bool
            Forwarded to the ``from_pretrained`` calls.
        token : str | None
            Forwarded to the ``from_pretrained`` calls.
        default_species : str
            Species used when the input dict has no ``"species"`` entry.
        genome_cache_dir : str | Path
            Stored with ``~`` expanded; not otherwise used by this class.
        device : str
            Device specification resolved by ``_pick_device``.
        mps_force_cpu : bool
            Whether long inputs move the model from MPS to CPU.
        mps_force_cpu_length : int
            Token count from which the MPS -> CPU switch applies.
        verbose : bool
            Whether to print device information.
        pred_center_fraction : float
            Length of the prediction interval as a fraction of the window.
        pred_center_offset_fraction : float
            Offset of the prediction interval as a fraction of the window.
        **kwargs : Any
            Forwarded to ``Pipeline.__init__``.

        Raises
        ------
        ValueError
            If ``model`` is a module and no tokenizer is given.
        """
        self.model_id = model if isinstance(model, str) else None
        self.default_species = default_species
        # Expand "~" so the default value is a usable filesystem path.
        self.genome_cache_dir = Path(genome_cache_dir).expanduser()
        self.mps_force_cpu = bool(mps_force_cpu)
        self.mps_force_cpu_length = int(mps_force_cpu_length)
        self.verbose = bool(verbose)
        self.pred_center_fraction = float(pred_center_fraction)
        self.pred_center_offset_fraction = float(pred_center_offset_fraction)

        if isinstance(model, str):
            self.config = AutoConfig.from_pretrained(
                model, trust_remote_code=trust_remote_code, token=token
            )
            self.model = AutoModel.from_pretrained(
                model, trust_remote_code=trust_remote_code, token=token
            )
        else:
            self.model = model
            self.config = getattr(model, "config", None)

        if tokenizer is None:
            if not self.model_id:
                raise ValueError(
                    "If passing a model module, pass tokenizer explicitly."
                )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, trust_remote_code=trust_remote_code, token=token,
            )
        elif isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer, trust_remote_code=trust_remote_code, token=token
            )
        else:
            self.tokenizer = tokenizer

        # Extract model_id from config if not already set
        # (following ntv3_gff_pipeline.py pattern)
        if self.model_id is None and self.config is not None:
            self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
                self.config, "name_or_path", None
            )

        # bed names: the config attribute is read under either spelling
        self.bed_element_names = getattr(
            self.config, "bed_elements_names", None
        ) or getattr(self.config, "bed_element_names", None)

        self._target_device = _pick_device(device)
        self.model.to(self._target_device)
        self.model.eval()

        # device=-1 is passed to the base class; this class places the model
        # itself (see the forward() override below).
        super().__init__(
            model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
        )

    def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
        """
        Return BigWig track IDs for the assembly corresponding to `species`.
        No model forward pass.

        Parameters
        ----------
        species : str | None
            Species name. If None, ``default_species`` is used.

        Raises
        ------
        ValueError
            If the species is not present in the checkpoint config.
        """
        if species is None:
            # Honor the declared default instead of looking up the key None.
            species = self.default_species
        if species not in self.config.bigwigs_per_species:
            raise ValueError(
                f"Species {species} not found in checkpoint config. "
                f"Available: {list(self.config.bigwigs_per_species.keys())}"
            )
        return list(self.config.bigwigs_per_species[species])

    def available_bed_element_names(self, species: str | None = None) -> list[str]:
        """
        Return BED element names available in this checkpoint for the given species.
        Filters elements based on species-specific training data availability.

        Parameters
        ----------
        species : str | None
            Species name (e.g., "human", "mouse"). If None, returns all elements
            without filtering (for backward compatibility).

        Returns
        -------
        list[str]
            Filtered list of BED element names available for this species
        """
        all_elements = list(self.bed_element_names or [])
        if species is None:
            return all_elements
        return _filter_bed_elements_by_species(all_elements, species)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts."""
        return {}, {}, {}

    def _get_model_device(self) -> torch.device:  # noqa: CCE001
        """Return the device of the model's first parameter."""
        return next(self.model.parameters()).device

    def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
        """
        Return ``(species, assembly)`` for ``inputs``.

        The species comes from ``inputs["species"]`` or ``default_species``.

        Raises
        ------
        ValueError
            If the species has no known assembly, or is missing from the
            checkpoint config.
        """
        species = inputs.get("species", self.default_species)
        if species not in SPECIES_TO_ASSEMBLY:
            supported = sorted(SPECIES_TO_ASSEMBLY.keys())
            raise ValueError(
                f"Unsupported species='{species}'. " f"Supported species: {supported}"
            )
        assembly = SPECIES_TO_ASSEMBLY[species]
        cfg_species = list(self.config.bigwigs_per_species.keys())
        if species not in cfg_species:
            raise ValueError(
                f"Species '{species}' is not available in this checkpoint. "
                f"Available species: {cfg_species}"
            )
        return species, assembly

    def _maybe_force_cpu_for_mps_long(  # noqa: CCE001
        self, input_ids_cpu: torch.Tensor
    ) -> torch.device:
        """
        Return the device to run on, moving the model to CPU for long MPS inputs.

        The switch happens when ``mps_force_cpu`` is set, the model is on MPS
        and the token count reaches ``mps_force_cpu_length``.

        NOTE(review): the model is not moved back afterwards, so it stays on
        CPU for later calls as well, although the message says "for this run".
        """
        dev = self._get_model_device()
        if self.mps_force_cpu and dev.type == "mps":
            seq_len = int(input_ids_cpu.shape[-1])
            if seq_len >= self.mps_force_cpu_length:
                if self.verbose:
                    print(
                        f"[NTv3TracksPipeline] MPS detected and input is long "
                        f"(tokens={seq_len}). Switching model + inputs to CPU "
                        "for this run."
                    )
                self.model.to("cpu")
                self.model.eval()
                return torch.device("cpu")
        return dev

    def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
        """
        Turn an input dict into model tensors plus metadata.

        Parameters
        ----------
        inputs : dict[str, Any]
            Either ``{"seq": ...}`` or ``{"chrom": ..., "start": ..., "end": ...}``,
            optionally with ``"species"``. A non-None ``"seq"`` takes
            precedence over coordinates.

        Returns
        -------
        dict[str, Any]
            ``"input_ids"`` and ``"species_ids"`` tensors on the run device and
            a ``"meta"`` dict describing the window.
        """
        species, assembly = self._resolve_species_and_assembly(inputs)

        # Resolve sequence
        if "seq" in inputs and inputs["seq"] is not None:
            seq = _sanitize_dna(inputs["seq"])
            chrom = None
            start = 0
            end = len(seq)
            window_len = len(seq)
        else:
            chrom = inputs["chrom"]
            start = int(inputs["start"])
            end = int(inputs["end"])
            window_len = end - start
            seq = _get_dna_sequence(assembly, chrom, start, end)
            seq = _sanitize_dna(seq)

        # Tokenize with padding
        batch = self.tokenizer(
            [seq],
            add_special_tokens=False,
            padding=True,
            pad_to_multiple_of=128,
            return_tensors="pt",
        )
        input_ids_cpu = batch["input_ids"]

        # MPS-long fallback decision
        device = self._maybe_force_cpu_for_mps_long(input_ids_cpu)

        # Move inputs
        input_ids = input_ids_cpu.to(device)

        # Species tokenization - match batch size
        batch_size = input_ids.shape[0]
        species_ids = self.model.encode_species([species] * batch_size)
        species_ids_tensor = species_ids.to(device)

        # Prediction interval (not used for slicing logits, just x-axis)
        pred_start = start + int(window_len * self.pred_center_offset_fraction)
        pred_end = pred_start + int(window_len * self.pred_center_fraction)

        # The checkpoint config is the source of truth for track IDs/names.
        bigwig_track_names = list(self.config.bigwigs_per_species[species])

        return {
            "input_ids": input_ids,
            "species_ids": species_ids_tensor,
            "meta": {
                "chrom": chrom,
                "start": start,
                "end": end,
                "species": species,
                "assembly": assembly,
                "window_len": window_len,
                "pred_start": pred_start,
                "pred_end": pred_end,
                "bigwig_track_names": bigwig_track_names,
            },
        }

    # prevent Pipeline from moving tensors to its own device
    def forward(self, model_inputs, **forward_params):
        """Call ``_forward`` directly, bypassing the base-class device handling."""
        return self._forward(model_inputs, **forward_params)

    def postprocess(
        self, model_outputs: dict[str, Any], **kwargs: Any
    ) -> NTv3TracksOutput:
        """
        Convert raw model output into an ``NTv3TracksOutput``.

        Logits are moved to NumPy, leading batch (and optional assembly)
        dimensions are dropped, and the BED logits are restricted to the
        elements available for the species so that their element axis always
        matches ``bed_element_names``.

        Raises
        ------
        ValueError
            If the BigWig or BED logits have an unexpected number of dimensions.
        """
        # Extract model_output and meta from the dict returned by _forward
        if isinstance(model_outputs, dict) and "model_output" in model_outputs:
            model_out = model_outputs["model_output"]
            meta = model_outputs.get("meta", {})
        else:
            # Fallback for direct ModelOutput (shouldn't happen with current code)
            model_out = model_outputs
            meta = {}

        def to_np(x):
            return x.detach().float().cpu().numpy()

        # The model output is read by key.
        bigwig_np = to_np(model_out["bigwig_tracks_logits"])
        bed_np = to_np(model_out["bed_tracks_logits"])
        mlm_np = to_np(model_out["logits"])

        # Normalize shapes to remove batch/(optional assembly) dims
        if bigwig_np.ndim == 3:
            bigwig_np = bigwig_np[0]  # (L, T)
        elif bigwig_np.ndim == 4:
            bigwig_np = bigwig_np[0, 0]  # (L, T) if (B, A, L, T)
        else:
            raise ValueError(f"Unexpected bigwig_tracks_logits ndim: {bigwig_np.ndim}")
        if bed_np.ndim == 4:
            bed_np = bed_np[0]  # (L, E, C)
        elif bed_np.ndim == 5:
            bed_np = bed_np[0, 0]  # (L, E, C) if (B, A, L, E, C)
        else:
            raise ValueError(f"Unexpected bed_tracks_logits ndim: {bed_np.ndim}")
        if mlm_np.ndim == 3:
            mlm_np = mlm_np[0]

        # Filter BED elements based on species
        species = meta.get("species")
        all_bed_element_names = list(self.bed_element_names or [])
        filtered_bed_element_names = all_bed_element_names
        if species and all_bed_element_names:
            filtered_bed_element_names = _filter_bed_elements_by_species(
                all_bed_element_names, species
            )
            if filtered_bed_element_names != all_bed_element_names:
                # First position of every name, so each lookup is O(1) and
                # duplicates resolve the same way list.index() would.
                first_index: dict[str, int] = {}
                for idx, name in enumerate(all_bed_element_names):
                    first_index.setdefault(name, idx)
                element_indices = [
                    first_index[name] for name in filtered_bed_element_names
                ]
                # Always slice, even when nothing is left, so the element axis
                # of the logits never disagrees with the returned names.
                bed_np = bed_np[:, element_indices, :]

        return NTv3TracksOutput(
            bigwig_tracks_logits=bigwig_np,
            bed_tracks_logits=bed_np,
            mlm_logits=mlm_np,
            chrom=meta.get("chrom"),
            start=meta.get("start"),
            end=meta.get("end"),
            species=meta.get("species"),
            assembly=meta.get("assembly"),
            bigwig_track_names=meta.get("bigwig_track_names"),
            bed_element_names=filtered_bed_element_names,
            window_len=meta.get("window_len"),
            pred_start=meta.get("pred_start"),
            pred_end=meta.get("pred_end"),
        )

    def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
        """
        Run the model without gradients.

        Removes ``"meta"`` from ``model_inputs`` and returns it next to the
        model output so ``postprocess`` can use it.
        """
        meta = model_inputs.pop("meta")
        if self.verbose:
            print(f"Running on device: {self._get_model_device()}")
        with torch.no_grad():
            out = self.model(
                input_ids=model_inputs["input_ids"],
                species_ids=model_inputs["species_ids"],
            )
        # Return the model output and meta side by side rather than attaching
        # meta to the output object.
        return {"model_output": out, "meta": meta}

    def __call__(
        self,
        inputs,
        *args,
        plot: bool = False,
        tracks_to_plot: dict[str, str] | None = None,  # title -> track_id (ENCSR...)
        elements_to_plot: list[str] | None = None,  # element names
        plot_height: float = 1.0,
        plot_figsize_x: float = 20.0,
        **kwargs,
    ):
        """
        One-step call that can optionally plot and always returns NTv3TracksOutput.

        Parameters
        ----------
        inputs
            Input dict as accepted by ``preprocess``.
        plot : bool
            Whether to draw the selected tracks after inference.
        tracks_to_plot : dict[str, str] | None
            Plot title -> BigWig track id.
        elements_to_plot : list[str] | None
            BED element names; the probability of class index 1 is plotted.
        plot_height : float
            Figure height per track.
        plot_figsize_x : float
            Figure width.

        Raises
        ------
        ValueError
            When plotting and track/element names are missing from the output,
            a requested track or element is unknown, or nothing was selected.
        """
        out: NTv3TracksOutput = super().__call__(inputs, *args, **kwargs)
        if plot:
            if out.bigwig_track_names is None:
                raise ValueError(
                    "bigwig_track_names missing; expected "
                    "cfg.bigwigs_per_species[species]."
                )
            if out.bed_element_names is None:
                raise ValueError("bed element names missing from config.")
            tracks_to_plot = tracks_to_plot or {}
            elements_to_plot = elements_to_plot or []
            bigwig_names = out.bigwig_track_names
            bed_element_names = out.bed_element_names

            # Validate
            missing_tracks = [
                tid for tid in tracks_to_plot.values() if tid not in bigwig_names
            ]
            if missing_tracks:
                raise ValueError(
                    f"The following tracks are not available in "
                    f"bigwig_names: {missing_tracks}\n"
                    f"First 50 available: {bigwig_names[:50]}"
                    f"{'...' if len(bigwig_names) > 50 else ''}"
                )
            missing_elements = [
                e for e in elements_to_plot if e not in bed_element_names
            ]
            if missing_elements:
                first_50 = bed_element_names[:50]
                ellipsis = "..." if len(bed_element_names) > 50 else ""
                raise ValueError(
                    f"The following elements are not available in "
                    f"bed_element_names: {missing_elements}\n"
                    f"First 50 available: {first_50}{ellipsis}"
                )

            # Build bigwig tracks dict (title -> y)
            bigwig_tracks: dict[str, np.ndarray] = {}
            bigwig = out.bigwig_tracks_logits  # (L_pred, T)
            for title, track_id in tracks_to_plot.items():
                track_idx = bigwig_names.index(track_id)
                bigwig_tracks[title] = bigwig[:, track_idx]

            # Bed positive class probabilities (title -> y)
            bed_probs: dict[str, np.ndarray] = {}
            probs = _softmax_last(out.bed_tracks_logits)  # (L_pred, E, C)
            for element_name in elements_to_plot:
                element_idx = bed_element_names.index(element_name)
                bed_probs[element_name] = probs[:, element_idx, 1]

            all_tracks = {**bigwig_tracks, **bed_probs}
            if not all_tracks:
                # Fail with the plotting helper's own error instead of a bare
                # StopIteration from next(iter(...)) below.
                raise ValueError("No tracks to plot.")
            plot_start = int(out.pred_start or 0)
            plot_end = int(
                out.pred_end or (plot_start + len(next(iter(all_tracks.values()))))
            )
            _plot_tracks_fillbetween(
                all_tracks,
                chrom=out.chrom,
                start=plot_start,
                end=plot_end,
                assembly=out.assembly,
                height=plot_height,
                figsize_x=plot_figsize_x,
            )
        return out
def load_ntv3_tracks_pipeline(
    model: str,
    device: str = "auto",
    **pipeline_kwargs: Any,
):
    """
    Build an ``NTv3TracksPipeline`` for an NTv3 checkpoint with remote code trusted.

    Parameters
    ----------
    model:
        Checkpoint id, e.g. "InstaDeepAI/NTv3_100M", "InstaDeepAI/NTv3_650M", ...
    device:
        "auto", "cpu", "cuda", "mps"
    pipeline_kwargs:
        Extra kwargs passed to NTv3TracksPipeline
        (default_species, genome_cache_dir, etc.).

    Returns
    -------
    The constructed ``NTv3TracksPipeline``.
    """
    return NTv3TracksPipeline(
        model=model,
        trust_remote_code=True,
        device=device,
        **pipeline_kwargs,
    )