"""Gene-level methylation heatmap generation."""
import csv
import logging
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from modalysis.core.expression import parse_expression_field
from modalysis.core.gene_regions import (
GeneRegionsByChromosome,
build_gene_regions,
find_genes_at_position,
parse_gff,
)
from modalysis.core.plots.label_format import format_modification_label
# Module-level logger; records are emitted under this module's name.
logger = logging.getLogger(__name__)
# Expression profiles that get a heatmap. Genes with profile "NDE" are also
# collected by _collect_genes_by_combination but are never rendered.
EXPRESSION_PROFILES_TO_PLOT = ["UP", "DOWN"]
# Sign classes for a DMR's EFFECT_SIZE: NON_NEGATIVE for >= 0, NEGATIVE otherwise.
EFFECT_SIGNS = ["NON_NEGATIVE", "NEGATIVE"]
# Gene-region categories. These names are the gene-list column headers read
# from annotated DMR files, and their order is the heatmap column order.
REGIONS = ["PROMOTER", "BODY", "ENHANCER"]
# (manifestation, expression_profile, effect_sign, modification, region)
CombinationKey = tuple[str, str, str, str, str]
# Combination key -> upper-cased gene ids selected by DMRs for that combination.
GeneSetsByCombination = dict[CombinationKey, set[str]]
# (gene_id, region), both upper-cased.
GeneRegionKey = tuple[str, str]
# Gene region -> [summed n_valid_cov, summed n_mod] accumulated from a merged pileup.
GeneAccum = dict[GeneRegionKey, list[int]]
# Upper-cased manifestation -> upper-cased expression label.
ManifestationExpressionMap = dict[str, str]
# Upper-cased gene id -> {expression label -> expression profile}.
GeneExpressionMap = dict[str, dict[str, str]]
[docs]
def _collect_genes_by_combination(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_to_expression_label: ManifestationExpressionMap,
    gene_to_expression: GeneExpressionMap,
) -> GeneSetsByCombination:
    """Collect DMR-associated genes for every plotting combination.

    Each annotated DMR file is paired positionally with a manifestation and a
    modification. For every row, the sign of ``EFFECT_SIZE`` is combined with
    each gene listed (comma-separated) in the ``REGIONS`` columns and with that
    gene's expression profile under the manifestation's expression label.
    Genes missing from *gene_to_expression*, or whose profile is neither in
    ``EXPRESSION_PROFILES_TO_PLOT`` nor ``"NDE"``, are skipped.

    Args:
        annotated_dmr_paths: Tab-separated annotated DMR files with a header
            row containing ``EFFECT_SIZE`` and one column per entry of
            ``REGIONS``.
        manifestations: Manifestation for each DMR file (same order).
        modifications: Modification for each DMR file (same order).
        manifestation_to_expression_label: Upper-cased manifestation ->
            upper-cased expression label.
        gene_to_expression: Upper-cased gene id -> {expression label: profile}.

    Returns:
        Mapping of (manifestation, expression_profile, effect_sign,
        modification, region) -> set of upper-cased gene ids.

    Raises:
        ValueError: If a manifestation has no configured expression label.
            Errors from parsing ``EFFECT_SIZE`` as a float also propagate.
    """
    # Built once instead of once per gene. "NDE" genes are kept here even
    # though only EXPRESSION_PROFILES_TO_PLOT are rendered.
    allowed_profiles = frozenset(EXPRESSION_PROFILES_TO_PLOT) | {"NDE"}
    genes_by_combination: GeneSetsByCombination = {}
    num_processed_rows = 0
    for dmr_path, manifestation, modification in zip(
        annotated_dmr_paths, manifestations, modifications
    ):
        normalized_manifestation = manifestation.strip().upper()
        normalized_modification = modification.strip().upper()
        if normalized_manifestation not in manifestation_to_expression_label:
            raise ValueError(
                "No expression label configured for manifestation '%s'"
                % normalized_manifestation
            )
        expression_label = manifestation_to_expression_label[normalized_manifestation]
        # `with` closes the file even when a row fails to parse.
        with open(dmr_path, newline="") as dmr_file:
            for row in csv.DictReader(dmr_file, delimiter="\t"):
                effect_size = float(row["EFFECT_SIZE"])
                effect_sign = "NON_NEGATIVE" if effect_size >= 0 else "NEGATIVE"
                for region in REGIONS:
                    # DictReader fills fields missing from a short row with
                    # None; treat those like an empty gene list. A missing
                    # column header still raises KeyError.
                    genes_field = (row[region] or "").strip()
                    if not genes_field:
                        continue
                    for raw_gene in genes_field.split(","):
                        gene = raw_gene.strip().upper()
                        if not gene or gene not in gene_to_expression:
                            continue
                        expression_profile = gene_to_expression[gene].get(
                            expression_label
                        )
                        if expression_profile not in allowed_profiles:
                            continue
                        key = (
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                            region,
                        )
                        genes_by_combination.setdefault(key, set()).add(gene)
                num_processed_rows += 1
    logger.info("Collected gene sets from %s annotated DMR rows.", num_processed_rows)
    return genes_by_combination
[docs]
def _accumulate_pileup_per_gene(
    merged_pileup_path: str,
    regions: GeneRegionsByChromosome,
) -> GeneAccum:
    """Sum pileup coverage and modification counts per gene region.

    Reads a tab-separated merged pileup file whose first line is a header.
    Each data row's start position is looked up with ``find_genes_at_position``
    in the promoter, body and enhancer indexes of its chromosome. Rows on
    chromosomes absent from *regions* are ignored. Counts are summed per
    (gene_id, region).

    Columns used (0-based): 0 chromosome, 1 start, 4 n_valid_cov, 5 n_mod.

    Args:
        merged_pileup_path: Path to the merged pileup TSV.
        regions: Per-chromosome region index (as produced by
            ``build_gene_regions``) with ``"<region>"`` and
            ``"<region>_starts"`` entries for promoter, body and enhancer.

    Returns:
        Mapping of (upper-cased gene_id, upper-cased region name) ->
        ``[sum_n_valid_cov, sum_n_mod]``. The mapping is empty when the file
        has no data rows.
    """
    # Pairs of (key used in `regions`, region name used in GeneAccum keys),
    # computed once instead of once per assigned gene.
    region_names = [(name, name.upper()) for name in ("promoter", "body", "enhancer")]
    gene_accum: GeneAccum = {}
    num_rows = 0
    num_assigned = 0
    # `with` closes the file even when a row fails to parse.
    with open(merged_pileup_path, newline="") as input_file:
        reader = csv.reader(input_file, delimiter="\t")
        header = next(reader, None)
        if header is None:
            # Previously surfaced as a bare StopIteration; an empty file simply
            # contributes no pileup data.
            logger.warning("Merged pileup file %s is empty.", merged_pileup_path)
            return gene_accum
        logger.debug("Merged pileup header: %s", header)
        for row in reader:
            # csv.reader yields [] for blank lines (e.g. stray newlines at EOF).
            if not row:
                continue
            chromosome = row[0]
            start = int(row[1])
            n_valid_cov = int(row[4])
            n_mod = int(row[5])
            num_rows += 1
            chrom_regions = regions.get(chromosome)
            if chrom_regions is None:
                continue
            for region_key, region_label in region_names:
                gene_ids = find_genes_at_position(
                    start,
                    chrom_regions[region_key],
                    chrom_regions[f"{region_key}_starts"],
                )
                for gene_id in gene_ids:
                    totals = gene_accum.setdefault((gene_id.upper(), region_label), [0, 0])
                    totals[0] += n_valid_cov
                    totals[1] += n_mod
                    num_assigned += 1
    logger.info(
        "Accumulated pileup per gene from %s: %s rows read, %s gene-region assignments.",
        merged_pileup_path,
        num_rows,
        num_assigned,
    )
    return gene_accum
[docs]
def _normalize_effect_signs(effect_signs: list[str] | None) -> list[str]:
    """Return the effect signs to plot: upper-cased, de-duplicated, in request order.

    ``None`` selects every entry of ``EFFECT_SIGNS``.

    Raises:
        ValueError: If a requested sign is not one of ``EFFECT_SIGNS``.
    """
    if effect_signs is None:
        return list(EFFECT_SIGNS)
    signs_to_plot: list[str] = []
    for sign in effect_signs:
        normalized_sign = sign.strip().upper()
        if normalized_sign not in EFFECT_SIGNS:
            raise ValueError("Unsupported effect sign: %s" % sign)
        if normalized_sign not in signs_to_plot:
            signs_to_plot.append(normalized_sign)
    return signs_to_plot


def _load_gene_expression(annotated_gff_path: str) -> GeneExpressionMap:
    """Map each gene id in the annotated GFF table to its expression profiles.

    The file is tab-separated with a header row containing ``GENE_ID`` and,
    optionally, ``EXPRESSION`` (parsed by ``parse_expression_field``). Gene ids
    are stripped and upper-cased; a later row for the same id replaces an
    earlier one.
    """
    gene_to_expression: GeneExpressionMap = {}
    with open(annotated_gff_path, newline="") as gff_file:
        for row in csv.DictReader(gff_file, delimiter="\t"):
            gene_id = row["GENE_ID"].strip().upper()
            gene_to_expression[gene_id] = parse_expression_field(
                row.get("EXPRESSION", "")
            )
    return gene_to_expression


def _mean_modification_matrix(genes: list[str], gene_accum: GeneAccum) -> np.ndarray:
    """Build a (len(genes), len(REGIONS)) float array of summed n_mod / summed n_valid_cov.

    Cells with no pileup data, or with zero valid coverage, are 0.0.
    """
    matrix = np.zeros((len(genes), len(REGIONS)), dtype=float)
    for row_index, gene in enumerate(genes):
        for col_index, region in enumerate(REGIONS):
            n_valid_cov, n_mod = gene_accum.get((gene, region), (0, 0))
            if n_valid_cov > 0:
                matrix[row_index, col_index] = n_mod / n_valid_cov
    return matrix


def _render_heatmap(
    matrix: np.ndarray,
    genes: list[str],
    title: str,
    output_file_path: Path,
    vmax: float,
    show_gene_labels: bool,
) -> None:
    """Draw one genes x regions heatmap on a fixed [0, vmax] color scale and save it as PNG."""
    if show_gene_labels:
        row_height, fig_width = 0.4, 6
    else:
        row_height, fig_width = 0.15, 4
    fig_height = max(2, len(genes) * row_height + 1.5)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        sns.heatmap(
            matrix,
            annot=False,
            cmap="YlOrRd",
            # "PROMOTER" -> "Promoter", etc., so labels always follow REGIONS.
            xticklabels=[region.capitalize() for region in REGIONS],
            yticklabels=genes if show_gene_labels else False,
            ax=ax,
            cbar_kws={"label": "Mean Modification"},
            linewidths=0.3,
            linecolor="white",
            vmin=0.0,
            vmax=vmax,
        )
        ax.set_title("%s (%d genes)" % (title, len(genes)), fontsize=12, pad=10)
        ax.set_xlabel("Region", fontsize=10)
        if show_gene_labels:
            ax.set_ylabel("Gene", fontsize=10)
            ax.tick_params(axis="y", labelsize=7)
        else:
            ax.set_ylabel("")
        fig.tight_layout()
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails, so pyplot does
        # not keep accumulating open figures.
        plt.close(fig)


def plot_gene_heatmap(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_labels: list[str],
    expression_labels: list[str],
    annotated_gff_path: str,
    gff_path: str,
    merged_pileup_paths: list[str],
    pileup_manifestations: list[str],
    pileup_modifications: list[str],
    output_path: str,
    output_name: str,
    show_gene_labels: bool = False,
    effect_signs: list[str] | None = None,
) -> None:
    """Render per-combination heatmaps using DMR-selected genes and pileup means.

    For every manifestation in *manifestation_labels*, every profile in
    ``EXPRESSION_PROFILES_TO_PLOT``, every selected effect sign and every
    modification paired with that manifestation in the DMR inputs, the genes
    selected by the annotated DMRs are plotted as rows. There is one column per
    ``REGIONS`` entry, and each cell holds the mean modification
    (summed n_mod / summed n_valid_cov) from the matching merged pileup. All
    heatmaps share one color scale, [0, global max] (1.0 if every value is 0).

    Args:
        annotated_dmr_paths: Annotated DMR TSV files.
        manifestations: Manifestation for each DMR file (same length).
        modifications: Modification for each DMR file (same length).
        manifestation_labels: Manifestations to plot.
        expression_labels: Expression label in the annotated GFF for each
            manifestation label (same length).
        annotated_gff_path: TSV with ``GENE_ID`` and ``EXPRESSION`` columns.
        gff_path: GFF used to build promoter/body/enhancer regions.
        merged_pileup_paths: Merged pileup TSV files.
        pileup_manifestations: Manifestation for each pileup (same length).
        pileup_modifications: Modification for each pileup (same length).
        output_path: Directory for the PNGs; created if missing.
        output_name: Prefix of every output file name. Files are named
            ``<output_name>_<MANIFESTATION>_<PROFILE>_<SIGN>_<MODIFICATION>.png``.
        show_gene_labels: Label heatmap rows with gene ids (taller figures).
        effect_signs: Subset of ``EFFECT_SIGNS`` to plot; ``None`` plots all.

    Raises:
        ValueError: On mismatched parallel-list lengths, an unsupported effect
            sign, or a DMR manifestation without an expression label.
    """
    logger.info(
        "Plotting gene heatmaps. Annotated GFF: %s, GFF: %s, Output dir: %s",
        annotated_gff_path,
        gff_path,
        output_path,
    )
    # Validate input lengths
    if len(annotated_dmr_paths) != len(manifestations):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of manifestations (%d)"
            % (len(annotated_dmr_paths), len(manifestations))
        )
    if len(annotated_dmr_paths) != len(modifications):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of modifications (%d)"
            % (len(annotated_dmr_paths), len(modifications))
        )
    if len(manifestation_labels) != len(expression_labels):
        raise ValueError(
            "Number of manifestation labels (%d) must match number of expression labels (%d)"
            % (len(manifestation_labels), len(expression_labels))
        )
    if len(merged_pileup_paths) != len(pileup_manifestations):
        raise ValueError(
            "Number of merged pileup paths (%d) must match number of pileup manifestations (%d)"
            % (len(merged_pileup_paths), len(pileup_manifestations))
        )
    if len(merged_pileup_paths) != len(pileup_modifications):
        raise ValueError(
            "Number of merged pileup paths (%d) must match number of pileup modifications (%d)"
            % (len(merged_pileup_paths), len(pileup_modifications))
        )
    signs_to_plot = _normalize_effect_signs(effect_signs)

    # Step 1: manifestation -> expression label (later duplicates win).
    manifestation_to_expression_label: ManifestationExpressionMap = {
        manifestation_label.strip().upper(): expression_label.strip().upper()
        for manifestation_label, expression_label in zip(
            manifestation_labels, expression_labels
        )
    }

    # Step 2: expression profiles from the annotated GFF.
    gene_to_expression = _load_gene_expression(annotated_gff_path)
    logger.info("Loaded expression data for %s genes.", len(gene_to_expression))

    # Step 3: genes per combination from the annotated DMRs.
    genes_by_combination = _collect_genes_by_combination(
        annotated_dmr_paths,
        manifestations,
        modifications,
        manifestation_to_expression_label,
        gene_to_expression,
    )

    # Step 4: gene regions from the formatted GFF.
    genes_by_chromosome = parse_gff(gff_path)
    regions = build_gene_regions(genes_by_chromosome)

    # Step 5: (manifestation, modification) -> pileup path (later duplicates win).
    pileup_lookup: dict[tuple[str, str], str] = {
        (pileup_man.strip().upper(), pileup_mod.strip().upper()): pileup_path
        for pileup_path, pileup_man, pileup_mod in zip(
            merged_pileup_paths, pileup_manifestations, pileup_modifications
        )
    }
    logger.info("Pileup lookup keys: %s", list(pileup_lookup.keys()))

    # Step 6: read each pileup file exactly once.
    gene_accum_cache: dict[tuple[str, str], GeneAccum] = {}
    for pileup_key, pileup_path in pileup_lookup.items():
        logger.info("Accumulating pileup for %s from %s", pileup_key, pileup_path)
        gene_accum_cache[pileup_key] = _accumulate_pileup_per_gene(pileup_path, regions)

    # Step 7: build every matrix first so the color scale can be shared.
    # Modifications observed per manifestation in the DMR inputs; this does not
    # depend on profile or sign, so it is computed once.
    mods_by_manifestation: dict[str, set[str]] = {}
    for dmr_manifestation, modification in zip(manifestations, modifications):
        mods_by_manifestation.setdefault(dmr_manifestation.strip().upper(), set()).add(
            modification.strip().upper()
        )

    heatmap_data = []
    warned_missing_pileup: set[tuple[str, str]] = set()
    # dict.fromkeys de-duplicates labels that normalize to the same value while
    # keeping their order; duplicates would only re-render (and overwrite)
    # identical files.
    for normalized_manifestation in dict.fromkeys(
        label.strip().upper() for label in manifestation_labels
    ):
        modifications_to_plot = sorted(
            mods_by_manifestation.get(normalized_manifestation, set())
        )
        for expression_profile in EXPRESSION_PROFILES_TO_PLOT:
            for effect_sign in signs_to_plot:
                for normalized_modification in modifications_to_plot:
                    all_genes: set[str] = set()
                    for region in REGIONS:
                        all_genes |= genes_by_combination.get(
                            (
                                normalized_manifestation,
                                expression_profile,
                                effect_sign,
                                normalized_modification,
                                region,
                            ),
                            set(),
                        )
                    if not all_genes:
                        logger.debug(
                            "No genes for %s %s %s %s, skipping.",
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                        )
                        continue
                    sorted_genes = sorted(all_genes)
                    pileup_key = (normalized_manifestation, normalized_modification)
                    gene_accum = gene_accum_cache.get(pileup_key, {})
                    if not gene_accum and pileup_key not in warned_missing_pileup:
                        # Warn once per key rather than once per combination.
                        warned_missing_pileup.add(pileup_key)
                        logger.warning(
                            "No pileup data for %s. Available keys: %s",
                            pileup_key,
                            list(gene_accum_cache.keys()),
                        )
                    display_modification = format_modification_label(
                        normalized_modification
                    )
                    heatmap_data.append(
                        {
                            "matrix": _mean_modification_matrix(sorted_genes, gene_accum),
                            "genes": sorted_genes,
                            "title": "%s %s %s"
                            % (
                                normalized_manifestation,
                                expression_profile.capitalize(),
                                display_modification,
                            ),
                            "file_name": "%s_%s_%s_%s_%s"
                            % (
                                output_name,
                                normalized_manifestation,
                                expression_profile,
                                effect_sign,
                                normalized_modification,
                            ),
                        }
                    )
    if not heatmap_data:
        logger.info("No heatmaps to generate (no gene combinations found).")
        return

    # Shared vmax across all matrices; fall back to 1.0 so an all-zero scale
    # is still a valid color range.
    global_vmax = max(entry["matrix"].max() for entry in heatmap_data)
    if global_vmax == 0.0:
        global_vmax = 1.0
    logger.info("Shared color scale: vmin=0.0, vmax=%.6f", global_vmax)

    # Step 8: render all heatmaps with the shared scale.
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    sns.set_theme(style="whitegrid")
    num_plots = 0
    for entry in heatmap_data:
        # Append the extension instead of using Path.with_suffix: with_suffix
        # replaces everything after a "." already present in the name (e.g. an
        # output_name of "run.v2"), which made different heatmaps overwrite
        # one another.
        output_file_path = output_dir / f"{entry['file_name']}.png"
        _render_heatmap(
            entry["matrix"],
            entry["genes"],
            entry["title"],
            output_file_path,
            global_vmax,
            show_gene_labels,
        )
        logger.info(
            "Saved heatmap to %s (%s genes).",
            output_file_path,
            len(entry["genes"]),
        )
        num_plots += 1
    logger.info("Generated %s heatmap plots.", num_plots)