Source code for modalysis.core.plots.dmr_dotplot

"""DMR position dotplot generation within promoter/body/enhancer regions."""

import csv
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import DefaultDict

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

from modalysis.core.expression import parse_expression_field
from modalysis.core.gene_regions import parse_gff
from modalysis.core.plots.label_format import format_modification_label

logger = logging.getLogger(__name__)

# Expression profiles that get their own plots; genes whose profile is not in
# this list are skipped while collecting DMR positions.
EXPRESSION_PROFILES = ["UP", "DOWN", "NDE"]
# Effect-sign buckets: a DMR is NON_NEGATIVE when EFFECT_SIZE >= 0, else NEGATIVE.
EFFECT_SIGNS = ["NON_NEGATIVE", "NEGATIVE"]
# Region names; also the annotated-DMR column names holding comma-separated gene IDs.
REGIONS = ["PROMOTER", "BODY", "ENHANCER"]

# Clamp bounds for promoter/enhancer positions (plotted on axes labelled in bp).
PROMOTER_SIZE = 1000
ENHANCER_SIZE = 1000
# Consensus-window widths, in the units of each panel's x axis
# (bp for promoter/enhancer, percent of gene length for the body).
PROMOTER_WINDOW_SIZE = 100.0
BODY_WINDOW_SIZE = 10.0
ENHANCER_WINDOW_SIZE = 100.0

# manifestation (uppercase) -> expression label (uppercase)
ManifestationExpressionMap = dict[str, str]
# gene_id (uppercase) -> {expression label -> expression profile}
GeneExpressionMap = dict[str, dict[str, str]]
# gene_id (uppercase) -> (chromosome, start, end)
GeneCoords = dict[str, tuple[str, int, int]]
# (manifestation, expression_profile, effect_sign, modification, gene_id, region)
PositionKey = tuple[str, str, str, str, str, str]
# PositionKey -> DMR positions within that gene region
DmrPositions = dict[PositionKey, list[float]]
# (position, gene_id) pair fed to the consensus-window search
RegionPoint = tuple[float, str]
# gene_id -> {region -> [positions]}
GenePositions = dict[str, dict[str, list[float]]]


def _build_gene_coordinate_lookup(gff_path: str) -> GeneCoords:
    """Build a lookup from gene_id -> (chromosome, start, end) using the formatted GFF file.

    The file is read as tab-delimited text with a header row that provides
    the GENE_ID, CHROMOSOME, START and END columns. Gene IDs are stripped and
    upper-cased; when a gene ID appears more than once the last row wins.

    Args:
        gff_path: Path to the formatted, tab-delimited GFF table.

    Returns:
        dict: gene_id (uppercase) -> (chromosome, start, end)

    Raises:
        KeyError: If a required column is missing from the header.
        ValueError: If START or END contains non-integer text.
    """
    gene_coords: GeneCoords = {}

    # The context manager guarantees the handle is released even when a row
    # fails to parse part-way through the file.
    with open(gff_path, newline="") as input_file:
        reader = csv.DictReader(input_file, delimiter="\t")
        for row in reader:
            gene_id = row["GENE_ID"].strip().upper()
            gene_coords[gene_id] = (
                row["CHROMOSOME"],
                int(row["START"]),
                int(row["END"]),
            )

    logger.info("Built coordinate lookup for %s genes.", len(gene_coords))
    return gene_coords
def _collect_dmr_positions(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_to_expression_label: ManifestationExpressionMap,
    gene_to_expression: GeneExpressionMap,
    gene_coords: GeneCoords,
) -> DmrPositions:
    """Read annotated DMR files and collect the position of each DMR within its gene region.

    Each annotated DMR file is tab-delimited and must provide START, END,
    EFFECT_SIZE and one column per entry of REGIONS holding a comma-separated
    list of gene IDs. The DMR midpoint is expressed relative to every listed
    gene that has both a recognised expression profile and known coordinates.

    Args:
        annotated_dmr_paths: Paths of the annotated DMR files.
        manifestations: Manifestation of each file (parallel to the paths).
        modifications: Modification of each file (parallel to the paths).
        manifestation_to_expression_label: manifestation (uppercase) ->
            expression label used to look up a gene's profile.
        gene_to_expression: gene_id (uppercase) -> {expression label -> profile}.
        gene_coords: gene_id (uppercase) -> (chromosome, start, end).

    Returns:
        dict: (manifestation, expression_profile, effect_sign, modification, gene_id, region)
              -> list of float positions
              For PROMOTER: distance from gene start (-1000 = far upstream, 0 = gene start)
              For BODY: percentage (0-100)
              For ENHANCER: distance from gene end (0-1000)

    Raises:
        ValueError: If a manifestation has no configured expression label, or
            a numeric column cannot be parsed.
        KeyError: If a required column is missing from a DMR file.
    """
    positions: DefaultDict[PositionKey, list[float]] = defaultdict(list)
    num_processed_rows = 0

    for dmr_path, manifestation, modification in zip(
        annotated_dmr_paths, manifestations, modifications
    ):
        normalized_manifestation = manifestation.strip().upper()
        normalized_modification = modification.strip().upper()

        if normalized_manifestation not in manifestation_to_expression_label:
            raise ValueError(
                "No expression label configured for manifestation '%s'"
                % normalized_manifestation
            )
        expression_label = manifestation_to_expression_label[normalized_manifestation]

        # `with` releases the handle even if a malformed row raises mid-file.
        with open(dmr_path, newline="") as dmr_file:
            dmr_reader = csv.DictReader(dmr_file, delimiter="\t")
            for row in dmr_reader:
                dmr_start = int(row["START"])
                dmr_end = int(row["END"])
                dmr_midpoint = (dmr_start + dmr_end) / 2.0
                effect_size = float(row["EFFECT_SIZE"])
                effect_sign = "NON_NEGATIVE" if effect_size >= 0 else "NEGATIVE"

                for region in REGIONS:
                    genes_field = row[region].strip()
                    if not genes_field:
                        continue

                    for raw_gene in genes_field.split(","):
                        gene = raw_gene.strip().upper()
                        if not gene:
                            continue
                        if gene not in gene_to_expression:
                            continue

                        expression_profile = gene_to_expression[gene].get(
                            expression_label
                        )
                        if expression_profile not in EXPRESSION_PROFILES:
                            continue
                        if gene not in gene_coords:
                            continue

                        _chrom, gene_start, gene_end = gene_coords[gene]

                        if region == "PROMOTER":
                            # Distance from gene start:
                            # -PROMOTER_SIZE = far upstream, 0 = at gene start
                            position = dmr_midpoint - gene_start
                            position = max(-PROMOTER_SIZE, min(0, position))
                        elif region == "BODY":
                            # Percentage through gene body:
                            # 0% = gene start, 100% = gene end
                            gene_length = gene_end - gene_start
                            if gene_length > 0:
                                position = (
                                    (dmr_midpoint - gene_start) / gene_length
                                ) * 100.0
                                position = max(0, min(100, position))
                            else:
                                # Zero-length (or inverted) gene: no meaningful
                                # percentage, so place the DMR mid-body.
                                position = 50.0
                        elif region == "ENHANCER":
                            # Distance from gene end:
                            # 0 = at gene end, ENHANCER_SIZE = far downstream
                            position = dmr_midpoint - gene_end
                            position = max(0, min(ENHANCER_SIZE, position))
                        else:
                            continue

                        key = (
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                            gene,
                            region,
                        )
                        positions[key].append(position)

                num_processed_rows += 1

    logger.info(
        "Collected DMR positions from %s annotated DMR rows.", num_processed_rows
    )
    return dict(positions)
def _find_consensus_window(
    region_points: list[RegionPoint],
    window_size: float,
    min_genes: int,
) -> tuple[float, float] | None:
    """Find a window containing points from at least min_genes distinct genes.

    Points are swept in ascending position order with a two-pointer window
    whose span never exceeds ``window_size``. Among all window states that
    cover at least ``min_genes`` distinct genes, the one with the most
    distinct genes is kept; ties go to the smaller start position.

    Returns:
        (start, start + window_size) of the chosen window, or None when no
        window qualifies (including empty input or min_genes <= 0).
    """
    if not region_points or min_genes <= 0:
        return None

    ordered = sorted(region_points, key=lambda point: point[0])
    genes_in_window = defaultdict(int)
    tail = 0
    best_count = 0
    best_start = None

    for head_pos, head_gene in ordered:
        genes_in_window[head_gene] += 1

        # Drop points from the trailing edge until the span fits again.
        while head_pos - ordered[tail][0] > window_size:
            dropped_gene = ordered[tail][1]
            tail += 1
            genes_in_window[dropped_gene] -= 1
            if not genes_in_window[dropped_gene]:
                del genes_in_window[dropped_gene]

        count = len(genes_in_window)
        if count < min_genes:
            continue

        window_start = ordered[tail][0]
        is_better = (
            best_start is None
            or count > best_count
            or (count == best_count and window_start < best_start)
        )
        if is_better:
            best_count = count
            best_start = window_start

    if best_start is None:
        return None
    return (best_start, best_start + window_size)
def _render_dotplot(
    gene_positions: GenePositions,
    title: str,
    output_file_path: Path,
    show_gene_labels: bool = False,
) -> bool:
    """Render a single dotplot PNG.

    The figure has three side-by-side panels (promoter, gene body, enhancer).
    Every gene occupies one horizontal row, shared across panels, and each
    DMR position is drawn as a dot on its gene's row. When a window of the
    panel's configured width covers at least half of the genes that have
    DMRs in that region, it is outlined with a red rectangle.

    Args:
        gene_positions: dict of gene_id -> {region -> [positions]}
                        where region is PROMOTER, BODY, or ENHANCER
        title: plot title string
        output_file_path: Path object for output file
        show_gene_labels: when True, gene IDs are shown as y tick labels on
                          the leftmost panel and the figure is made wider

    Returns:
        True if a plot was written, False if gene_positions was empty.
    """
    sorted_genes = sorted(gene_positions.keys())
    num_genes = len(sorted_genes)

    if num_genes == 0:
        logger.debug("No genes for %s, skipping.", title)
        return False

    gene_to_y = {gene: i for i, gene in enumerate(sorted_genes)}
    gene_rows = list(range(num_genes))

    region_configs = [
        {
            "name": "PROMOTER",
            "title": "Promoter (1000 bp)",
            "axis_min": -PROMOTER_SIZE,
            "axis_max": 0,
            "xlabel": "Distance from gene start (bp)",
            "window_size": PROMOTER_WINDOW_SIZE,
            "ticks": [-1000, -800, -600, -400, -200, 0],
            "x_pad": 1.0,
        },
        {
            "name": "BODY",
            "title": "Gene Body (%)",
            "axis_min": 0,
            "axis_max": 100,
            "xlabel": "Position (%)",
            "window_size": BODY_WINDOW_SIZE,
            "ticks": [0, 20, 40, 60, 80, 100],
            "x_pad": 0.5,
        },
        {
            "name": "ENHANCER",
            "title": "Enhancer (1000 bp)",
            "axis_min": 0,
            "axis_max": ENHANCER_SIZE,
            "xlabel": "Distance from gene end (bp)",
            "window_size": ENHANCER_WINDOW_SIZE,
            "ticks": [0, 200, 400, 600, 800, 1000],
            "x_pad": 1.0,
        },
    ]

    dot_color = "#1f77b4"
    dot_size = 30

    sns.set_theme(style="whitegrid")

    # Figure height grows with the number of gene rows, with a floor so that
    # plots with very few genes stay readable.
    row_height = 0.15
    fig_height = max(2.3, num_genes * row_height + 1.5)
    fig_width = 12 if show_gene_labels else 10

    fig, axes = plt.subplots(
        1,
        3,
        figsize=(fig_width, fig_height),
        sharey=False,
        gridspec_kw={"wspace": 0.15 if show_gene_labels else 0.08},
    )

    # pyplot keeps every open figure alive in a global registry, so the
    # figure must be closed even when drawing or saving raises.
    try:
        for ax_idx, (ax, config) in enumerate(zip(axes, region_configs)):
            region_name = config["name"]
            display_xmin = config["axis_min"] - config["x_pad"]
            display_xmax = config["axis_max"] + config["x_pad"]

            # Draw one horizontal guide line per gene row in a single call
            ax.hlines(
                gene_rows,
                display_xmin,
                display_xmax,
                colors="#cccccc",
                linewidths=0.5,
                zorder=1,
            )

            # Plot dots for DMR positions
            x_values = []
            y_values = []
            region_points = []
            region_genes = set()

            for gene in sorted_genes:
                region_data = gene_positions[gene]
                if region_name in region_data:
                    region_genes.add(gene)
                    for pos in region_data[region_name]:
                        x_values.append(pos)
                        y_values.append(gene_to_y[gene])
                        region_points.append((pos, gene))

            if x_values:
                ax.scatter(
                    x_values,
                    y_values,
                    s=dot_size,
                    color=dot_color,
                    edgecolor="white",
                    linewidth=0.5,
                    zorder=2,
                )

            if region_genes:
                # A consensus window must cover at least half of the genes
                # that have a DMR in this region.
                min_genes = math.ceil(0.5 * len(region_genes))
                consensus_window = _find_consensus_window(
                    region_points,
                    config["window_size"],
                    min_genes,
                )
                if consensus_window is not None:
                    window_start, window_end = consensus_window
                    # Keep the outline inside the region's nominal axis range
                    window_start = max(window_start, config["axis_min"])
                    window_end = min(window_end, config["axis_max"])
                    if window_end > window_start:
                        ax.add_patch(
                            Rectangle(
                                (window_start, -0.5),
                                window_end - window_start,
                                num_genes,
                                fill=False,
                                edgecolor="#d62728",
                                linewidth=1.2,
                                zorder=3,
                            )
                        )

            ax.set_xlim(display_xmin, display_xmax)
            ax.set_ylim(-0.5, num_genes - 0.5)
            ax.set_title(config["title"], fontsize=10, pad=8)
            ax.set_xlabel(config["xlabel"], fontsize=8)
            ax.set_xticks(config["ticks"])
            ax.xaxis.grid(False)
            ax.yaxis.grid(False)

            # Y-axis: show gene IDs if requested (only on leftmost panel)
            if show_gene_labels and ax_idx == 0:
                ax.set_yticks(gene_rows)
                ax.set_yticklabels(sorted_genes, fontsize=5)
                ax.spines["left"].set_visible(True)
                ax.tick_params(axis="y", which="both", length=0, pad=2)
            else:
                ax.set_yticks([])
                ax.set_ylabel("")
                ax.spines["left"].set_visible(False)

            # Only show bottom spine (and left spine when gene labels are on)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            ax.tick_params(axis="x", labelsize=7)

        fig.suptitle(
            "%s (%d genes)" % (title, num_genes),
            fontsize=11,
            y=1.02,
        )

        fig.subplots_adjust(
            top=0.88,
            left=0.15 if show_gene_labels else 0.05,
            wspace=0.15 if show_gene_labels else 0.08,
        )

        # Save through the figure object rather than pyplot's implicit
        # "current figure" state.
        fig.savefig(output_file_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)

    logger.info("Saved dotplot to %s (%d genes).", output_file_path, num_genes)
    return True
def plot_dmr_dotplot(
    annotated_dmr_paths: list[str],
    manifestations: list[str],
    modifications: list[str],
    manifestation_labels: list[str],
    expression_labels: list[str],
    annotated_gff_path: str,
    gff_path: str,
    output_path: str,
    output_name: str,
    show_gene_labels: bool = False,
    effect_signs: list[str] | None = None,
) -> None:
    """Render DMR position dotplots for each manifestation/expression/modification slice.

    One PNG is written per (manifestation, expression profile, effect sign,
    modification) combination that has at least one gene, named
    ``<output_name>_<MANIFESTATION>_<PROFILE>_<SIGN>_<MODIFICATION>.png``
    inside ``output_path``.

    Args:
        annotated_dmr_paths: Paths of the annotated DMR files.
        manifestations: Manifestation of each DMR file (parallel to the paths).
        modifications: Modification of each DMR file (parallel to the paths).
        manifestation_labels: Manifestations to plot.
        expression_labels: Expression label to use for each manifestation
            label (parallel to ``manifestation_labels``).
        annotated_gff_path: Tab-delimited table with a GENE_ID column and an
            EXPRESSION column parsed by ``parse_expression_field``.
        gff_path: Formatted GFF table providing gene coordinates.
        output_path: Directory the PNG files are written to.
        output_name: Prefix for the generated file names.
        show_gene_labels: Show gene IDs on the y axis of each plot.
        effect_signs: Subset of EFFECT_SIGNS to plot (case-insensitive);
            all signs are plotted when None.

    Raises:
        ValueError: If parallel input lists differ in length or an
            unsupported effect sign is requested.
    """
    logger.info(
        "Plotting DMR dotplots. Annotated GFF: %s, GFF: %s, Output dir: %s",
        annotated_gff_path,
        gff_path,
        output_path,
    )

    # Validate input lengths
    if len(annotated_dmr_paths) != len(manifestations):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of manifestations (%d)"
            % (len(annotated_dmr_paths), len(manifestations))
        )
    if len(annotated_dmr_paths) != len(modifications):
        raise ValueError(
            "Number of annotated DMR paths (%d) must match number of modifications (%d)"
            % (len(annotated_dmr_paths), len(modifications))
        )
    if len(manifestation_labels) != len(expression_labels):
        raise ValueError(
            "Number of manifestation labels (%d) must match number of expression labels (%d)"
            % (len(manifestation_labels), len(expression_labels))
        )

    if effect_signs is None:
        signs_to_plot = list(EFFECT_SIGNS)
    else:
        # Normalise, validate and de-duplicate while keeping the caller's order
        signs_to_plot = []
        for sign in effect_signs:
            normalized_sign = sign.strip().upper()
            if normalized_sign not in EFFECT_SIGNS:
                raise ValueError("Unsupported effect sign: %s" % sign)
            if normalized_sign not in signs_to_plot:
                signs_to_plot.append(normalized_sign)

    # Step 1: Build manifestation -> expression_label mapping
    manifestation_to_expression_label: ManifestationExpressionMap = {}
    for manifestation_label, expression_label in zip(
        manifestation_labels, expression_labels
    ):
        manifestation_to_expression_label[manifestation_label.strip().upper()] = (
            expression_label.strip().upper()
        )

    # Step 2: Read annotated GFF for expression profiles
    gene_to_expression: GeneExpressionMap = {}
    # `with` releases the handle even if a row fails to parse.
    with open(annotated_gff_path, newline="") as gff_file:
        gff_reader = csv.DictReader(gff_file, delimiter="\t")
        for row in gff_reader:
            gene_id = row["GENE_ID"].strip().upper()
            gene_to_expression[gene_id] = parse_expression_field(
                row.get("EXPRESSION", "")
            )

    logger.info("Loaded expression data for %s genes.", len(gene_to_expression))

    # Step 3: Build gene coordinate lookup from formatted GFF
    gene_coords = _build_gene_coordinate_lookup(gff_path)

    # Step 4: Collect DMR positions within gene regions
    positions = _collect_dmr_positions(
        annotated_dmr_paths,
        manifestations,
        modifications,
        manifestation_to_expression_label,
        gene_to_expression,
        gene_coords,
    )

    # Group positions once by plot slice so that each plot is a dictionary
    # lookup instead of a full rescan of every collected position.
    positions_by_slice: DefaultDict[
        tuple[str, str, str, str],
        DefaultDict[str, DefaultDict[str, list[float]]],
    ] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for key, pos_list in positions.items():
        (
            key_manifestation,
            key_expression,
            key_effect,
            key_modification,
            key_gene,
            key_region,
        ) = key
        slice_key = (key_manifestation, key_expression, key_effect, key_modification)
        positions_by_slice[slice_key][key_gene][key_region].extend(pos_list)

    # Step 5: Generate one plot per (manifestation, expression_profile, effect_sign, modification)
    num_plots = 0
    output_dir = Path(output_path)

    for manifestation_label in manifestation_labels:
        normalized_manifestation = manifestation_label.strip().upper()

        # Find unique modifications for this manifestation
        mods_for_manifestation = {
            modification.strip().upper()
            for dmr_manifestation, modification in zip(manifestations, modifications)
            if dmr_manifestation.strip().upper() == normalized_manifestation
        }

        for expression_profile in EXPRESSION_PROFILES:
            for effect_sign in signs_to_plot:
                for normalized_modification in sorted(mods_for_manifestation):
                    # .get() avoids creating empty entries in the defaultdict
                    gene_positions = positions_by_slice.get(
                        (
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                        )
                    )

                    if not gene_positions:
                        logger.debug(
                            "No genes for %s %s %s %s, skipping.",
                            normalized_manifestation,
                            expression_profile,
                            effect_sign,
                            normalized_modification,
                        )
                        continue

                    display_modification = format_modification_label(
                        normalized_modification
                    )
                    title = "%s %s %s" % (
                        normalized_manifestation,
                        expression_profile.capitalize(),
                        display_modification,
                    )
                    file_name = "%s_%s_%s_%s_%s" % (
                        output_name,
                        normalized_manifestation,
                        expression_profile,
                        effect_sign,
                        normalized_modification,
                    )
                    # Append the extension instead of using with_suffix():
                    # with_suffix() would cut the name at its last dot, so an
                    # output_name such as "run.v1" would collapse every plot
                    # into the same "run.png".
                    output_file_path = output_dir / (file_name + ".png")

                    rendered = _render_dotplot(
                        dict(gene_positions),
                        title,
                        output_file_path,
                        show_gene_labels=show_gene_labels,
                    )
                    if rendered:
                        num_plots += 1

    logger.info("Generated %s dotplot plots.", num_plots)