plot_interactive
Interactive Plotly plotting factory for MMM.
This module provides MMMPlotlyFactory, which creates interactive Plotly
visualizations from MMM summary data produced by MMMSummaryFactory.
The factory supports:
- Contributions: Bar charts showing channel/control/seasonality contributions
- ROAS: Return on Ad Spend analysis with confidence intervals
- Posterior Predictive: Time series with HDI bands comparing actual vs predicted
- Saturation Curves: Visualize diminishing returns per channel
- Adstock Curves: Show carryover effects over time
- Automatic faceting based on custom dimensions (e.g., geo, brand)
- Both pandas and Polars DataFrames, via Narwhals
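The saturation and adstock plots draw the transformations of the fitted model, and their exact shapes depend on how that model was configured. As an illustration only, the sketch below assumes a logistic saturation, (1 - e^(-lam*x)) / (1 + e^(-lam*x)), and an unnormalized geometric adstock with decay `alpha` and maximum lag `l_max`; both function names are hypothetical. It shows numerically what "diminishing returns" and "carryover" mean.

```python
import math


def logistic_saturation(x: float, lam: float) -> float:
    """Hypothetical helper: logistic saturation mapping spend x >= 0 into [0, 1)."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


def geometric_adstock(spend: list[float], alpha: float, l_max: int) -> list[float]:
    """Hypothetical helper: unnormalized geometric adstock.

    Spend at period t keeps contributing alpha**lag of its value at t + lag,
    for lags 0 .. l_max - 1.
    """
    weights = [alpha**lag for lag in range(l_max)]
    out = []
    for t in range(len(spend)):
        # Sum the decayed spend of the current and earlier periods.
        out.append(
            sum(w * spend[t - lag] for lag, w in enumerate(weights) if t - lag >= 0)
        )
    return out


# Diminishing returns: each extra unit of spend adds less response than the last.
gains = [logistic_saturation(x + 1, 0.5) - logistic_saturation(x, 0.5) for x in range(5)]

# Carryover: a single burst of spend keeps contributing in later periods.
carryover = geometric_adstock([100.0, 0.0, 0.0, 0.0], alpha=0.5, l_max=4)
```

Because the logistic curve is concave for x >= 0, the entries of `gains` shrink steadily, which is the flattening a saturation plot shows; `carryover` decays by half each period, which is the shape an adstock plot shows for `alpha=0.5`.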
Examples
Basic Usage via MMM Model
Access the plotting factory directly from a fitted MMM model:
>>> # Posterior predictive with actual vs predicted
>>> fig = mmm.plot_interactive.posterior_predictive()
>>> fig.show()
>>> # Channel contributions over time
>>> fig = mmm.plot_interactive.contributions()
>>> fig.show()
>>> # ROAS analysis aggregated by year
>>> fig = mmm.plot_interactive.roas(frequency="yearly")
>>> fig.show()
>>> # Saturation curves showing diminishing returns
>>> fig = mmm.plot_interactive.saturation_curves()
>>> fig.show()
>>> # Adstock curves showing carryover effects
>>> fig = mmm.plot_interactive.adstock_curves()
>>> fig.show()
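The HDI bands in the posterior predictive plot summarize, at each time step, the spread of the posterior predictive draws. A highest density interval (HDI) is the narrowest interval that contains a given share of the draws. The sketch below is a minimal standard-library version that assumes a unimodal posterior; it is not necessarily the routine the summary factory uses (ArviZ, for example, provides `az.hdi`), and the `hdi` helper, the `prob` values, and the draw data are hypothetical.

```python
import math


def hdi(samples: list[float], prob: float = 0.94) -> tuple[float, float]:
    """Hypothetical helper: narrowest interval holding `prob` of the samples."""
    if not samples:
        raise ValueError("samples must be non-empty")
    ordered = sorted(samples)
    n = len(ordered)
    window = max(1, math.ceil(prob * n))  # number of draws the interval must cover
    # Slide a window of `window` consecutive sorted draws and keep the narrowest.
    best = min(range(n - window + 1), key=lambda i: ordered[i + window - 1] - ordered[i])
    return ordered[best], ordered[best + window - 1]


# Hypothetical draws for one time step: most mass sits near 5, with outliers.
draw = [0.0, 5.0, 5.1, 5.2, 5.3, 5.4, 10.0, 11.0, 12.0, 20.0]
low, high = hdi(draw, prob=0.5)

# A band is one interval per time step of the series.
draws_per_step = [draw, [1.0, 2.0, 3.0]]
bands = [hdi(d, prob=0.5) for d in draws_per_step]
```

Unlike an equal-tailed interval, the HDI ignores the outliers at 0.0 and 20.0 here and returns the tight cluster (5.0, 5.4), which is why HDI bands hug the region where most draws fall.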
Customizing Plots
Control faceting and styling with kwargs:
>>> # ROAS colored by date, grouped by channel
>>> fig = mmm.plot_interactive.roas(frequency="yearly", color="date", x="channel")
>>> fig.show()
>>> # Disable auto-faceting and manually set facet column
>>> fig = mmm.plot_interactive.contributions(
... facet_col="country", title="Channel Effects by Country"
... )
>>> fig.show()
>>> # Saturation curves faceted by brand
>>> fig = mmm.plot_interactive.saturation_curves(
... facet_row="brand",
... )
>>> fig.show()
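ROAS compares the contribution attributed to a channel with the spend behind it. The sketch below assumes that `frequency="yearly"` sums contribution and spend within each calendar year and channel before dividing, evaluated for a single posterior draw. The aggregation rule the summary factory actually applies is not stated here, and `roas_by_year` and its row layout are hypothetical. Repeating the computation across posterior draws would give the per-draw values from which intervals can be computed.

```python
from collections import defaultdict
from datetime import date


def roas_by_year(rows):
    """Hypothetical helper: yearly ROAS per channel for one posterior draw.

    rows: iterable of (date, channel, contribution, spend) tuples.
    """
    totals = defaultdict(lambda: [0.0, 0.0])  # (year, channel) -> [contribution, spend]
    for day, channel, contribution, spend in rows:
        key = (day.year, channel)
        totals[key][0] += contribution
        totals[key][1] += spend
    # Ratio of sums; channel-years with zero spend are skipped.
    return {key: c / s for key, (c, s) in totals.items() if s > 0}


rows = [
    (date(2023, 1, 2), "tv", 300.0, 100.0),
    (date(2023, 6, 5), "tv", 100.0, 100.0),
    (date(2024, 1, 1), "tv", 150.0, 50.0),
    (date(2024, 1, 1), "search", 0.0, 0.0),
]
result = roas_by_year(rows)
```

With `x="channel"` and `color="date"`, as in the example above, each resulting (year, channel) value would correspond to one bar. A ratio of sums weights high-spend periods more heavily than an average of weekly ROAS values would, so the aggregation rule matters when comparing periods.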
Working with Filtered/Aggregated Data
Create custom factories with filtered or aggregated data:
>>> from pymc_marketing.mmm.summary import MMMSummaryFactory
>>> from pymc_marketing.mmm.plot_interactive import MMMPlotlyFactory
>>> # Aggregate multiple geos into one
>>> agg_data = mmm.data.aggregate_dims(
... dim="geo", values=["geo_a", "geo_b"], new_label="all_geos"
... )
>>> agg_summary = MMMSummaryFactory(agg_data, mmm)
>>> agg_factory = MMMPlotlyFactory(summary=agg_summary)
>>> fig = agg_factory.roas(frequency="yearly", color="channel", x="date")
>>> fig.show()
>>> # Filter to specific geo
>>> filtered_data = mmm.data.filter_dims(geo="geo_a")
>>> filtered_summary = MMMSummaryFactory(filtered_data, mmm, validate_data=False)
>>> filtered_factory = MMMPlotlyFactory(summary=filtered_summary)
>>> fig = filtered_factory.roas(frequency="yearly", color="channel", x="date")
>>> fig.show()
>>> # Filter by date range
>>> filtered_data = mmm.data.filter_dates(start_date="2024-01-01")
>>> filtered_summary = MMMSummaryFactory(filtered_data, mmm)
>>> filtered_factory = MMMPlotlyFactory(summary=filtered_summary)
>>> fig = filtered_factory.roas(frequency="quarterly", color="channel", x="date")
>>> fig.show()
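`aggregate_dims(dim="geo", values=[...], new_label=...)` merges several labels of a custom dimension into one. For additive quantities such as spend, that conceptually amounts to summing the matching rows under the new label. The sketch below illustrates the idea on plain dictionaries; `aggregate_labels` is a hypothetical stand-in, and both summation as the combining rule and keeping the non-listed geos unchanged are assumptions, not documented behavior of `mmm.data.aggregate_dims`.

```python
from collections import defaultdict


def aggregate_labels(records, values, new_label):
    """Hypothetical stand-in: sum spend of rows whose geo is in `values`.

    records: list of dicts with keys "geo", "date", "channel", "spend".
    Rows for other geos are kept unchanged (an assumption of this sketch).
    """
    merged = defaultdict(float)  # (date, channel) -> summed spend
    kept = []
    for rec in records:
        if rec["geo"] in values:
            merged[(rec["date"], rec["channel"])] += rec["spend"]
        else:
            kept.append(rec)
    # Emit one combined row per (date, channel) under the new label.
    kept.extend(
        {"geo": new_label, "date": d, "channel": ch, "spend": s}
        for (d, ch), s in merged.items()
    )
    return kept


records = [
    {"geo": "geo_a", "date": "2024-01-01", "channel": "tv", "spend": 10.0},
    {"geo": "geo_b", "date": "2024-01-01", "channel": "tv", "spend": 5.0},
    {"geo": "geo_c", "date": "2024-01-01", "channel": "tv", "spend": 7.0},
]
result = aggregate_labels(records, ["geo_a", "geo_b"], "all_geos")
```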
Classes
Factory for creating interactive Plotly plots from MMM summary data.