import matplotlib.pyplot as plt
import numpy as np
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pymc_marketing.mmm.transformers import tanh_saturation
# (b, c) parameter pairs for tanh_saturation; each pair is drawn as one curve.
params = [
    (0.75, 0.25),
    (0.75, 1.5),
    (1, 0.25),
    (1, 1),
    (1, 1.5),
]

# 100 evenly spaced spend values on [0, 5], wrapped as a 1-D xtensor with dim "x".
x_np = np.linspace(0, 5, 100)
x = ptx.as_xtensor(pt.as_tensor_variable(x_np), dims=("x",))

ax = plt.subplot(111)

# Eagerly evaluate the symbolic saturation curve for every pair and draw it
# on the axes created above.
for b, c in params:
    y = tanh_saturation(x, b=b, c=c).eval()
    ax.plot(x_np, y, label=f"b = {b}\nc = {c}")

ax.set_xlabel("spend", fontsize=12)
ax.set_ylabel("f(spend)", fontsize=12)

# Shrink the axes to 80% of their width so the legend, anchored just outside
# the right edge, stays inside the figure.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.show()