Source code for modalysis.core.plots.gene_heatmap

"""Gene-level methylation heatmap generation."""

import csv
import logging
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from modalysis.core.expression import parse_expression_field
from modalysis.core.gene_regions import (
    GeneRegionsByChromosome,
    build_gene_regions,
    find_genes_at_position,
    parse_gff,
)
from modalysis.core.plots.label_format import format_modification_label

logger = logging.getLogger(__name__)

# Expression profiles that each get their own heatmap. Genes whose profile is
# "NDE" are also collected from the DMR files (_collect_genes_by_combination
# accepts EXPRESSION_PROFILES_TO_PLOT + ["NDE"]) but are never plotted.
EXPRESSION_PROFILES_TO_PLOT = ["UP", "DOWN"]
# Buckets for a DMR row's EFFECT_SIZE: values >= 0 are NON_NEGATIVE, the rest
# NEGATIVE.
EFFECT_SIGNS = ["NON_NEGATIVE", "NEGATIVE"]
# Region columns read from annotated DMR files; also the heatmap column order.
REGIONS = ["PROMOTER", "BODY", "ENHANCER"]

# (manifestation, expression_profile, effect_sign, modification, region),
# every element upper-cased.
CombinationKey = tuple[str, str, str, str, str]
# Combination key -> set of upper-cased gene IDs.
GeneSetsByCombination = dict[CombinationKey, set[str]]
# (gene_id, region), both upper-cased.
GeneRegionKey = tuple[str, str]
# (gene_id, region) -> [sum of n_valid_cov, sum of n_mod] from a merged pileup.
GeneAccum = dict[GeneRegionKey, list[int]]
# Normalized manifestation label -> normalized expression label.
ManifestationExpressionMap = dict[str, str]
# Upper-cased gene ID -> {expression label: expression profile}, as produced
# by parse_expression_field.
GeneExpressionMap = dict[str, dict[str, str]]


def _collect_genes_by_combination(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_to_expression_label: ManifestationExpressionMap,
    gene_to_expression: GeneExpressionMap,
) -> GeneSetsByCombination:
    """Read annotated DMR files and group their genes by combination.

    Each annotated DMR file is paired positionally with one manifestation and
    one modification. For every row, the sign of ``EFFECT_SIZE`` is taken,
    and every gene listed (comma-separated) in the ``PROMOTER``, ``BODY`` and
    ``ENHANCER`` columns is recorded under the key
    ``(manifestation, expression_profile, effect_sign, modification, region)``.
    A gene is skipped when it has no expression data, or when its profile for
    the manifestation's expression label is not ``UP``, ``DOWN`` or ``NDE``.

    Args:
        annotated_dmr_paths: Tab-separated annotated DMR files with an
            ``EFFECT_SIZE`` column and one column per entry of ``REGIONS``.
        manifestations: Manifestation of each DMR file (same order).
        modifications: Modification of each DMR file (same order).
        manifestation_to_expression_label: Normalized manifestation ->
            expression label used to look up each gene's profile.
        gene_to_expression: Upper-cased gene ID -> expression profiles keyed
            by expression label.

    Returns:
        dict: combination key -> set of upper-cased gene IDs.

    Raises:
        ValueError: If a manifestation has no configured expression label, or
            an ``EFFECT_SIZE`` value cannot be parsed as a float.
    """
    # Hoisted out of the per-gene loop; a tuple keeps the original
    # list-membership semantics.
    accepted_profiles = tuple(EXPRESSION_PROFILES_TO_PLOT) + ("NDE",)

    genes_by_combination: GeneSetsByCombination = {}
    num_processed_rows = 0
    for dmr_path, manifestation, modification in zip(
        annotated_dmr_paths, manifestations, modifications
    ):
        normalized_manifestation = manifestation.strip().upper()
        normalized_modification = modification.strip().upper()
        if normalized_manifestation not in manifestation_to_expression_label:
            raise ValueError(
                "No expression label configured for manifestation '%s'"
                % normalized_manifestation
            )
        expression_label = manifestation_to_expression_label[normalized_manifestation]

        # The context manager closes the file even when a row fails to parse.
        with open(dmr_path, newline="") as dmr_file:
            for row in csv.DictReader(dmr_file, delimiter="\t"):
                effect_size = float(row["EFFECT_SIZE"])
                # NOTE(review): a NaN EFFECT_SIZE compares False here and is
                # bucketed as NEGATIVE.
                effect_sign = "NON_NEGATIVE" if effect_size >= 0 else "NEGATIVE"
                for region in REGIONS:
                    genes_field = row[region].strip()
                    if not genes_field:
                        continue
                    for raw_gene in genes_field.split(","):
                        gene = raw_gene.strip().upper()
                        if not gene or gene not in gene_to_expression:
                            continue
                        expression_profile = gene_to_expression[gene].get(
                            expression_label
                        )
                        if expression_profile not in accepted_profiles:
                            continue
                        key = (
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                            region,
                        )
                        genes_by_combination.setdefault(key, set()).add(gene)
                num_processed_rows += 1

    logger.info("Collected gene sets from %s annotated DMR rows.", num_processed_rows)
    return genes_by_combination
def _accumulate_pileup_per_gene(
    merged_pileup_path: str,
    regions: GeneRegionsByChromosome,
) -> GeneAccum:
    """Sum pileup coverage and modification counts per gene region.

    The merged pileup is read as a tab-separated file whose first line is a
    header. In each data row, column 0 is the chromosome, column 1 the
    position, column 4 ``n_valid_cov`` and column 5 ``n_mod``. A row's counts
    are added to every gene that ``find_genes_at_position`` reports at that
    position, separately for each of the promoter, body and enhancer regions.
    Blank lines are skipped.

    Args:
        merged_pileup_path: Path to the merged pileup TSV.
        regions: Per-chromosome gene regions (from ``build_gene_regions``).
            Each chromosome entry is indexed with ``"<region>"`` and
            ``"<region>_starts"`` for promoter, body and enhancer.

    Returns:
        dict: ``(GENE_ID, REGION)`` (both upper-cased) ->
        ``[sum_n_valid_cov, sum_n_mod]``. Empty when the file has no rows.
    """
    gene_accum: GeneAccum = {}
    # Derive region names from REGIONS so the accumulated keys always match
    # the (gene, region) keys the heatmap looks up. Lower-case names index
    # the structures produced by build_gene_regions.
    region_names = [region.lower() for region in REGIONS]
    num_rows = 0
    num_assigned = 0

    with open(merged_pileup_path, newline="") as input_file:
        reader = csv.reader(input_file, delimiter="\t")
        header = next(reader, None)
        if header is None:
            # Completely empty file: treat like a header-only file instead of
            # letting a bare StopIteration escape.
            logger.warning("Merged pileup file %s is empty.", merged_pileup_path)
            return gene_accum
        logger.debug("Merged pileup header: %s", header)

        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines.
                continue
            chromosome = row[0]
            start = int(row[1])
            n_valid_cov = int(row[4])
            n_mod = int(row[5])
            num_rows += 1
            if chromosome not in regions:
                continue
            chrom_regions = regions[chromosome]
            for region_name in region_names:
                gene_ids = find_genes_at_position(
                    start,
                    chrom_regions[region_name],
                    chrom_regions[f"{region_name}_starts"],
                )
                for gene_id in gene_ids:
                    counts = gene_accum.setdefault(
                        (gene_id.upper(), region_name.upper()), [0, 0]
                    )
                    counts[0] += n_valid_cov
                    counts[1] += n_mod
                    num_assigned += 1

    logger.info(
        "Accumulated pileup per gene from %s: %s rows read, %s gene-region assignments.",
        merged_pileup_path,
        num_rows,
        num_assigned,
    )
    return gene_accum
def _resolve_effect_signs(effect_signs: list[str] | None) -> list[str]:
    """Return the normalized, de-duplicated effect signs to plot (default: all)."""
    if effect_signs is None:
        return list(EFFECT_SIGNS)
    signs_to_plot: list[str] = []
    for sign in effect_signs:
        normalized_sign = sign.strip().upper()
        if normalized_sign not in EFFECT_SIGNS:
            raise ValueError("Unsupported effect sign: %s" % sign)
        if normalized_sign not in signs_to_plot:
            signs_to_plot.append(normalized_sign)
    return signs_to_plot


def _load_gene_expression(annotated_gff_path: str) -> dict[str, dict[str, str]]:
    """Map each upper-cased GENE_ID of the annotated GFF TSV to its parsed EXPRESSION field."""
    gene_to_expression: GeneExpressionMap = {}
    with open(annotated_gff_path, newline="") as gff_file:
        for row in csv.DictReader(gff_file, delimiter="\t"):
            gene_id = row["GENE_ID"].strip().upper()
            gene_to_expression[gene_id] = parse_expression_field(
                row.get("EXPRESSION", "")
            )
    logger.info("Loaded expression data for %s genes.", len(gene_to_expression))
    return gene_to_expression


def _mean_modification_matrix(
    genes: list[str], gene_accum: dict[tuple[str, str], list[int]]
) -> np.ndarray:
    """Build a genes x REGIONS matrix of sum(n_mod) / sum(n_valid_cov), 0.0 where uncovered."""
    matrix = []
    for gene in genes:
        row_values = []
        for region in REGIONS:
            counts = gene_accum.get((gene, region))
            if counts is not None and counts[0] > 0:
                row_values.append(counts[1] / counts[0])
            else:
                row_values.append(0.0)
        matrix.append(row_values)
    return np.array(matrix)


def _render_heatmap(
    entry: dict, output_dir: Path, vmax: float, show_gene_labels: bool
) -> Path:
    """Draw one heatmap entry on a fixed [0, vmax] color scale and save it as a PNG."""
    matrix_np = entry["matrix"]
    sorted_genes = entry["genes"]
    if show_gene_labels:
        row_height, fig_width = 0.4, 6
    else:
        row_height, fig_width = 0.15, 4
    fig_height = max(2, len(sorted_genes) * row_height + 1.5)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        sns.heatmap(
            matrix_np,
            annot=False,
            cmap="YlOrRd",
            # Column labels follow the REGIONS order used to build the matrix.
            xticklabels=[region.capitalize() for region in REGIONS],
            yticklabels=sorted_genes if show_gene_labels else False,
            ax=ax,
            cbar_kws={"label": "Mean Modification"},
            linewidths=0.3,
            linecolor="white",
            vmin=0.0,
            vmax=vmax,
        )
        ax.set_title(
            "%s (%d genes)" % (entry["title"], len(sorted_genes)),
            fontsize=12,
            pad=10,
        )
        ax.set_xlabel("Region", fontsize=10)
        if show_gene_labels:
            ax.set_ylabel("Gene", fontsize=10)
            ax.tick_params(axis="y", labelsize=7)
        else:
            ax.set_ylabel("")
        fig.tight_layout()
        # Append ".png" rather than using Path.with_suffix: with_suffix treats
        # any "." inside output_name or a label as an existing suffix and
        # replaces everything after it, truncating the name so different
        # heatmaps overwrite the same file.
        output_file_path = output_dir / f"{entry['file_name']}.png"
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return output_file_path


def plot_gene_heatmap(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_labels: list[str],
    expression_labels: list[str],
    annotated_gff_path: str,
    gff_path: str,
    merged_pileup_paths: list[str],
    pileup_manifestations: list[str],
    pileup_modifications: list[str],
    output_path: str,
    output_name: str,
    show_gene_labels: bool = False,
    effect_signs: list[str] | None = None,
) -> None:
    """Render per-combination heatmaps using DMR-selected genes and pileup means.

    One heatmap is produced for each combination of:

    * manifestation label,
    * expression profile (``UP``/``DOWN``),
    * requested effect sign,
    * modification seen in the DMR inputs for that manifestation,

    provided the combination has at least one gene. Rows are the genes
    collected from the annotated DMR files, in sorted order. Columns are the
    promoter, body and enhancer regions. Each cell is the gene region's
    summed ``n_mod`` divided by its summed ``n_valid_cov`` in the matching
    merged pileup, or 0.0 when uncovered.

    All heatmaps share one color scale ``[0, vmax]``, where ``vmax`` is the
    largest cell across all heatmaps (1.0 if every cell is zero). Each is
    saved as
    ``<output_name>_<MANIFESTATION>_<PROFILE>_<SIGN>_<MODIFICATION>.png``.

    Args:
        annotated_dmr_paths: Annotated DMR TSVs, paired positionally with
            ``manifestations`` and ``modifications``.
        manifestations: Manifestation of each annotated DMR file.
        modifications: Modification of each annotated DMR file.
        manifestation_labels: Manifestations to plot, paired positionally
            with ``expression_labels``.
        expression_labels: Expression label whose profile is used for each
            manifestation label.
        annotated_gff_path: TSV with ``GENE_ID`` and ``EXPRESSION`` columns.
        gff_path: GFF used to build gene promoter/body/enhancer regions.
        merged_pileup_paths: Merged pileup TSVs, paired positionally with
            ``pileup_manifestations`` and ``pileup_modifications``.
        pileup_manifestations: Manifestation of each merged pileup.
        pileup_modifications: Modification of each merged pileup.
        output_path: Directory for the PNG files (created if missing).
        output_name: Prefix for every output file name.
        show_gene_labels: Label rows with gene IDs (and use taller rows).
        effect_signs: Case-insensitive subset of ``EFFECT_SIGNS`` to plot;
            ``None`` plots all of them.

    Raises:
        ValueError: If positionally-paired inputs differ in length, an effect
            sign is unsupported, or a DMR manifestation has no configured
            expression label.
    """
    logger.info(
        "Plotting gene heatmaps. Annotated GFF: %s, GFF: %s, Output dir: %s",
        annotated_gff_path,
        gff_path,
        output_path,
    )

    # Validate that positionally-paired inputs line up.
    if len(annotated_dmr_paths) != len(manifestations):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of manifestations (%d)"
            % (len(annotated_dmr_paths), len(manifestations))
        )
    if len(annotated_dmr_paths) != len(modifications):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of modifications (%d)"
            % (len(annotated_dmr_paths), len(modifications))
        )
    if len(manifestation_labels) != len(expression_labels):
        raise ValueError(
            "Number of manifestation labels (%d) must match number of expression labels (%d)"
            % (len(manifestation_labels), len(expression_labels))
        )
    if len(merged_pileup_paths) != len(pileup_manifestations):
        raise ValueError(
            "Number of merged pileup paths (%d) must match number of pileup manifestations (%d)"
            % (len(merged_pileup_paths), len(pileup_manifestations))
        )
    if len(merged_pileup_paths) != len(pileup_modifications):
        raise ValueError(
            "Number of merged pileup paths (%d) must match number of pileup modifications (%d)"
            % (len(merged_pileup_paths), len(pileup_modifications))
        )

    signs_to_plot = _resolve_effect_signs(effect_signs)

    # Step 1: manifestation -> expression label (later duplicates win).
    manifestation_to_expression_label: ManifestationExpressionMap = {
        manifestation_label.strip().upper(): expression_label.strip().upper()
        for manifestation_label, expression_label in zip(
            manifestation_labels, expression_labels
        )
    }

    # Step 2: expression profiles from the annotated GFF.
    gene_to_expression = _load_gene_expression(annotated_gff_path)

    # Step 3: genes per combination from the annotated DMRs.
    genes_by_combination = _collect_genes_by_combination(
        annotated_dmr_paths,
        manifestations,
        modifications,
        manifestation_to_expression_label,
        gene_to_expression,
    )

    # Step 4: gene regions from the formatted GFF.
    regions = build_gene_regions(parse_gff(gff_path))

    # Step 5: (manifestation, modification) -> pileup path (later duplicates win).
    pileup_lookup: dict[tuple[str, str], str] = {}
    for pileup_path, pileup_man, pileup_mod in zip(
        merged_pileup_paths, pileup_manifestations, pileup_modifications
    ):
        pileup_lookup[(pileup_man.strip().upper(), pileup_mod.strip().upper())] = (
            pileup_path
        )
    logger.info("Pileup lookup keys: %s", list(pileup_lookup.keys()))

    # Step 6: read each pileup once and cache its per-gene accumulation.
    gene_accum_cache: dict[tuple[str, str], GeneAccum] = {}
    for pileup_key, pileup_path in pileup_lookup.items():
        logger.info("Accumulating pileup for %s from %s", pileup_key, pileup_path)
        gene_accum_cache[pileup_key] = _accumulate_pileup_per_gene(pileup_path, regions)

    # Step 7: build every matrix first so a shared color scale can be chosen.
    # The modifications seen per manifestation do not depend on the profile or
    # sign, so compute them once.
    modifications_by_manifestation: dict[str, set[str]] = {}
    for dmr_manifestation, modification in zip(manifestations, modifications):
        modifications_by_manifestation.setdefault(
            dmr_manifestation.strip().upper(), set()
        ).add(modification.strip().upper())

    heatmap_data: list[dict] = []
    for manifestation_label in manifestation_labels:
        normalized_manifestation = manifestation_label.strip().upper()
        dmr_modifications = sorted(
            modifications_by_manifestation.get(normalized_manifestation, set())
        )
        for expression_profile in EXPRESSION_PROFILES_TO_PLOT:
            for effect_sign in signs_to_plot:
                for normalized_modification in dmr_modifications:
                    all_genes: set[str] = set()
                    for region in REGIONS:
                        all_genes.update(
                            genes_by_combination.get(
                                (
                                    normalized_manifestation,
                                    expression_profile,
                                    effect_sign,
                                    normalized_modification,
                                    region,
                                ),
                                set(),
                            )
                        )
                    if not all_genes:
                        logger.debug(
                            "No genes for %s %s %s %s, skipping.",
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                        )
                        continue

                    sorted_genes = sorted(all_genes)
                    pileup_key = (normalized_manifestation, normalized_modification)
                    gene_accum = gene_accum_cache.get(pileup_key, {})
                    if not gene_accum:
                        logger.warning(
                            "No pileup data for %s. Available keys: %s",
                            pileup_key,
                            list(gene_accum_cache.keys()),
                        )

                    display_modification = format_modification_label(
                        normalized_modification
                    )
                    # NOTE(review): the title omits effect_sign, so the
                    # NON_NEGATIVE and NEGATIVE heatmaps share a title; only
                    # the file name tells them apart.
                    title = "%s %s %s" % (
                        normalized_manifestation,
                        expression_profile.capitalize(),
                        display_modification,
                    )
                    file_name = "%s_%s_%s_%s_%s" % (
                        output_name,
                        normalized_manifestation,
                        expression_profile,
                        effect_sign,
                        normalized_modification,
                    )
                    heatmap_data.append(
                        {
                            "matrix": _mean_modification_matrix(
                                sorted_genes, gene_accum
                            ),
                            "genes": sorted_genes,
                            "title": title,
                            "file_name": file_name,
                        }
                    )

    if not heatmap_data:
        logger.info("No heatmaps to generate (no gene combinations found).")
        return

    # Shared vmax across all matrices (every matrix has >= 1 gene row here).
    global_vmax = max(entry["matrix"].max() for entry in heatmap_data)
    if global_vmax == 0.0:
        global_vmax = 1.0
    logger.info("Shared color scale: vmin=0.0, vmax=%.6f", global_vmax)

    # Step 8: render every heatmap on the shared scale.
    output_dir = Path(output_path)
    # Create the directory up front so saving cannot fail only after all the
    # pileup processing above.
    output_dir.mkdir(parents=True, exist_ok=True)
    sns.set_theme(style="whitegrid")
    for entry in heatmap_data:
        output_file_path = _render_heatmap(
            entry, output_dir, global_vmax, show_gene_labels
        )
        logger.info(
            "Saved heatmap to %s (%s genes).",
            output_file_path,
            len(entry["genes"]),
        )
    logger.info("Generated %s heatmap plots.", len(heatmap_data))