import numpy as np
import matplotlib.pyplot as plt
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pymc_marketing.mmm.transformers import michaelis_menten

# Evaluation grid: 500 evenly spaced spend/impression levels in [0, 100].
x_np = np.linspace(0, 100, 500)
# The same grid as a PyTensor xtensor with a single named dimension 'x';
# this symbolic form is what gets passed to michaelis_menten below.
x = ptx.as_xtensor(pt.as_tensor_variable(x_np), dims=('x',))
alpha_values = [5, 10, 15]  # Different values of alpha
lam_values = [25, 50, 75]  # Different values of lam


def _plot_sweep(param_name, param_values, curve, title):
    """Plot one curve per value of a swept parameter in a new figure, then show it.

    Args:
        param_name: Name of the swept parameter, used in the legend labels.
        param_values: Values of the swept parameter; one line is drawn per value.
        curve: Callable mapping a parameter value to the y-values over ``x_np``.
        title: Figure title.
    """
    _, ax = plt.subplots(figsize=(8, 6))
    for value in param_values:
        ax.plot(x_np, curve(value), label=f"{param_name}={value}")
    ax.set_xlabel('Spend/Impressions (x)')
    ax.set_ylabel('Contribution (y)')
    ax.set_title(title)
    ax.legend()
    plt.show()


# Sweep lam while holding alpha at its first value; the fixed value is shown in
# the title so the figure is self-describing.
fixed_alpha = alpha_values[0]
_plot_sweep(
    "lam",
    lam_values,
    lambda lam: michaelis_menten(x, fixed_alpha, lam).eval(),
    f'Michaelis-Menten Function (Varying lam, alpha={fixed_alpha})',
)

# Sweep alpha while holding lam at its first value.
fixed_lam = lam_values[0]
_plot_sweep(
    "alpha",
    alpha_values,
    lambda alpha: michaelis_menten(x, alpha, fixed_lam).eval(),
    f'Michaelis-Menten Function (Varying alpha, lam={fixed_lam})',
)