# Source code for modalysis.core.plots.mean_methylation

"""Mean methylation line-plot generation across regions and chromosomes."""

import csv
import logging
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns

from modalysis.core.gene_regions import (
    GeneRegionsByChromosome,
    RegionList,
    build_gene_regions,
    find_genes_at_position,
    parse_gff,
)
from modalysis.core.plots.label_format import format_modification_label

logger = logging.getLogger(__name__)


# Running totals keyed by (chromosome, region_name); each value is the
# two-element list [sum_n_valid_cov, sum_n_mod].
RegionAccum = dict[tuple[str, str], list[int]]


def _find_overlapping_regions(
    position: int,
    region_list: RegionList,
    starts_list: list[int],
) -> bool:
    """Report whether ``position`` lies inside at least one region.

    The lookup is delegated to ``find_genes_at_position`` (a binary search
    over ``starts_list``); any non-empty result counts as an overlap.
    A region contains the position when
    ``region_start <= position < region_end``.

    Args:
        position: Coordinate to test.
        region_list: Regions to search.
        starts_list: Start coordinates matching ``region_list``, passed
            through to the lookup helper.

    Returns:
        True if one or more regions contain ``position``, otherwise False.
    """
    matches = find_genes_at_position(position, region_list, starts_list)
    if matches:
        return True
    return False
def _accumulate_pileup(
    merged_pileup_path: str,
    regions: GeneRegionsByChromosome,
) -> RegionAccum:
    """Sum coverage and modification counts per (chromosome, region).

    The merged pileup is read as a tab-separated file whose first row is a
    header. For every data row, column 0 is the chromosome, column 1 the
    start position, column 4 ``n_valid_cov`` and column 5 ``n_mod``. A row
    contributes to each of the promoter/body/enhancer regions that its start
    position overlaps on that chromosome.

    Args:
        merged_pileup_path: Path to the merged pileup file.
        regions: Per-chromosome region tables; each entry must provide the
            keys ``promoter``, ``body``, ``enhancer`` and their
            ``<name>_starts`` companions.

    Returns:
        dict: (chromosome, region_name) -> [sum_n_valid_cov, sum_n_mod].
        Empty when the file has no rows or nothing overlaps a region.

    Raises:
        ValueError: If a position or count column is not an integer.
        IndexError: If a data row has fewer than six columns.
    """
    region_names = ("promoter", "body", "enhancer")
    accum: RegionAccum = {}
    num_rows = 0
    num_assigned = 0

    # The context manager closes the handle even when a malformed row raises
    # part-way through the file.
    with open(merged_pileup_path, newline="") as input_file:
        reader = csv.reader(input_file, delimiter="\t")

        # A completely empty file has no header; treat it as "no data"
        # instead of letting StopIteration escape.
        header = next(reader, None)
        if header is None:
            logger.warning("Merged pileup %s is empty.", merged_pileup_path)
        else:
            logger.debug("Merged pileup header: %s", header)

        for row in reader:
            chromosome = row[0]
            start = int(row[1])
            n_valid_cov = int(row[4])
            n_mod = int(row[5])
            num_rows += 1

            # Rows on chromosomes without annotated regions are counted as
            # read but cannot be assigned anywhere.
            if chromosome not in regions:
                continue
            chrom_regions = regions[chromosome]

            for region_name in region_names:
                if not _find_overlapping_regions(
                    start,
                    chrom_regions[region_name],
                    chrom_regions[f"{region_name}_starts"],
                ):
                    continue
                totals = accum.setdefault((chromosome, region_name), [0, 0])
                totals[0] += n_valid_cov
                totals[1] += n_mod
                num_assigned += 1

    logger.info(
        "Accumulated pileup %s: %s rows read, %s region assignments.",
        merged_pileup_path,
        num_rows,
        num_assigned,
    )
    return accum
def _resolve_chromosome_order(
    available: list[str],
    requested: list[str] | None,
) -> list[str]:
    """Return ``available`` with the requested chromosomes moved to the front.

    Requested names are matched case-insensitively after stripping
    whitespace; unknown names and duplicates are ignored. Chromosomes that
    were not requested keep their relative order and are appended at the end.
    """
    if not requested:
        return available

    chromosome_by_upper = {chrom.upper(): chrom for chrom in available}
    ordered: list[str] = []
    seen: set[str] = set()
    for chrom in requested:
        canonical = chromosome_by_upper.get(chrom.strip().upper())
        if canonical is not None and canonical not in seen:
            ordered.append(canonical)
            seen.add(canonical)
    ordered.extend(chrom for chrom in available if chrom not in seen)
    return ordered


def plot_mean_methylation(
    gff_path: str,
    merged_pileup_paths: list[str],
    labels: list[str],
    output_path: str,
    output_name: str,
    y_min: float = 0.0,
    y_max: float = 0.1,
    chromosome_order: list[str] | None = None,
    plot_title: str | None = None,
) -> None:
    """Generate region-grouped chromosome methylation line plots.

    One line is drawn per merged pileup. The x-axis lists every chromosome
    three times, once for each of the promoter, body and enhancer groups,
    and the y-value is ``sum(n_mod) / sum(n_valid_cov)`` for that
    (chromosome, region) pair, or 0.0 when there is no coverage.

    Args:
        gff_path: GFF annotation used to derive the gene regions.
        merged_pileup_paths: Merged pileup files, one per plotted line.
        labels: Legend labels paired positionally with
            ``merged_pileup_paths``. If the lengths differ, a warning is
            logged and the surplus entries are ignored.
        output_path: Directory the image is written to.
        output_name: File name of the image; its suffix is replaced with
            ``.png``.
        y_min: Lower y-axis limit.
        y_max: Upper y-axis limit.
        chromosome_order: Optional chromosome names to place first
            (case-insensitive); remaining chromosomes follow in sorted order.
        plot_title: Optional title overriding the default one.

    Returns:
        None. The plot is saved as a PNG file.
    """
    output_file_path = (Path(output_path) / output_name).with_suffix(".png")
    logger.info(
        "Plotting mean methylation. GFF: %s, Pileups: %s, Output: %s",
        gff_path,
        merged_pileup_paths,
        output_file_path,
    )

    # zip() below stops at the shorter input; surface that instead of
    # silently dropping pileups or labels.
    if len(merged_pileup_paths) != len(labels):
        logger.warning(
            "Received %s pileup paths but %s labels; only the first %s "
            "pairs will be plotted.",
            len(merged_pileup_paths),
            len(labels),
            min(len(merged_pileup_paths), len(labels)),
        )

    # Step 1: Parse GFF and build region boundaries
    genes_by_chromosome = parse_gff(gff_path)
    regions = build_gene_regions(genes_by_chromosome)

    # Collect all chromosomes, sorted unless explicit order was provided.
    all_chromosomes = _resolve_chromosome_order(
        sorted(genes_by_chromosome.keys()), chromosome_order
    )
    logger.info("Chromosomes found: %s", all_chromosomes)

    # Step 2: Build x-axis tick labels and positions. The chromosome list is
    # repeated once per region group.
    region_names = ["promoter", "body", "enhancer"]
    num_chromosomes = len(all_chromosomes)
    x_labels = [chrom for _ in region_names for chrom in all_chromosomes]
    x_positions = list(range(len(x_labels)))

    # Step 3: For each pileup file, compute Y values
    sns.set_theme(style="white")
    fig, ax = plt.subplots(figsize=(max(20, num_chromosomes * 3), 8))
    try:
        for pileup_path, label in zip(merged_pileup_paths, labels):
            accum = _accumulate_pileup(pileup_path, regions)
            display_label = format_modification_label(label)

            y_values = []
            for region_name in region_names:
                for chrom in all_chromosomes:
                    total_cov, total_mod = accum.get((chrom, region_name), (0, 0))
                    # No coverage means no estimate; plot it as zero.
                    y_values.append(total_mod / total_cov if total_cov > 0 else 0.0)

            ax.plot(
                x_positions,
                y_values,
                marker="o",
                markersize=4,
                linewidth=1.5,
                label=display_label,
            )

        # Step 4: Add vertical separator lines between regions
        for i in range(1, len(region_names)):
            separator_x = i * num_chromosomes - 0.5
            ax.axvline(x=separator_x, color="gray", linestyle="--", linewidth=1.0)

        # Step 5: Add region group labels at the top
        for i, region_name in enumerate(region_names):
            center_x = (i * num_chromosomes + (i + 1) * num_chromosomes - 1) / 2
            ax.text(
                center_x,
                1.02,
                region_name.capitalize(),
                transform=ax.get_xaxis_transform(),
                ha="center",
                va="bottom",
                fontsize=14,
                fontweight="bold",
            )

        # Step 6: Configure axes
        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels, rotation=45, ha="right", fontsize=8)
        ax.set_ylabel("Mean Methylation", fontsize=12)
        ax.set_xlabel("Chromosome", fontsize=12)
        ax.set_title(
            plot_title or "Mean Methylation by Region and Chromosome",
            fontsize=16,
            pad=30,
        )
        ax.legend(title="Modification", fontsize=10)
        ax.set_xlim(-0.5, len(x_positions) - 0.5)
        ax.set_ylim(y_min, y_max)
        ax.grid(False)

        fig.tight_layout()
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure in a global registry; release it
        # even when accumulation or saving fails.
        plt.close(fig)

    logger.info("Saved plot to %s", output_file_path)