import matplotlib.pyplot as plt
import numpy as np
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pymc_marketing.mmm.transformers import delayed_adstock
# Each tuple is (alpha, theta, normalize), forwarded to delayed_adstock below.
params = [
    (0.25, 0, False),
    (0.25, 5, False),
    (0.75, 5, False),
    (0.75, 5, True),
]

# Unit impulse: one unit of spend at the first time step and zero afterwards,
# so each plotted curve is the response to a single spend event.
spend_np = np.zeros(15)
spend_np[0] = 1
# Wrap the array as a tensor with a named 'time' dimension; delayed_adstock is
# called with dim='time' below.
spend = ptx.as_xtensor(pt.as_tensor_variable(spend_np), dims=('time',))
x = np.arange(len(spend_np))

ax = plt.subplot(111)
for alpha, theta, normalize in params:
    # .eval() turns the symbolic pytensor result into concrete values to plot.
    y = delayed_adstock(
        spend, alpha=alpha, theta=theta, normalize=normalize, dim='time'
    ).eval()
    label = '\n'.join(
        (f'alpha = {alpha}', f'theta = {theta}', f'normalize = {normalize}')
    )
    ax.plot(x, y, label=label)

ax.set_xlabel('time since spend', fontsize=12)
ax.set_ylabel('f(time since spend)', fontsize=12)

# Shrink the axes to 65% of their original width to free space on the right,
# then anchor the legend just outside the plot area in that space.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.65, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()