import matplotlib.pyplot as plt
import numpy as np
import pytensor.tensor as pt
import pytensor.xtensor as ptx
from pymc_marketing.mmm.transformers import batched_convolution, ConvMode
# Visualise how each ConvMode of batched_convolution spreads a single unit
# spend through the kernel ``w``. The x-axis offsets and the title are derived
# from the two arrays below, so editing them keeps the figure's labels in sync.
spends_np = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], dtype=float)
w_np = np.array([0.75, 0.25, 0.125, 0.125])
# Wrap the NumPy data as labelled xtensors: the series lives on "time", the
# kernel on its own "time_kernel" dimension.
spends = ptx.as_xtensor(pt.as_tensor_variable(spends_np), dims=('time',))
w = ptx.as_xtensor(pt.as_tensor_variable(w_np), dims=('time_kernel',))
# x-axis: offset of every sample relative to the (first) spend, so the spend
# itself sits at 0 (here: -5 .. 5).
x = np.arange(spends_np.size) - int(np.argmax(spends_np))
ax = plt.subplot(111)
for mode in [ConvMode.Before, ConvMode.Overlap, ConvMode.After]:
    y = batched_convolution(
        spends, w, dim='time', kernel_dim='time_kernel', mode=mode
    ).eval()
    # The legend flags After as the default mode -- presumably matching
    # batched_convolution's signature; confirm if that default ever changes.
    suffix = "\n(default)" if mode == ConvMode.After else ""
    ax.plot(x, y, label=f'{mode.value}{suffix}')
ax.set_xlabel('time since spend', fontsize=12)
ax.set_ylabel('f(time since spend)', fontsize=12)
# Build the title from w_np so it cannot drift from the kernel actually used.
ax.set_title(f"1 spend at time 0 and w = {w_np.tolist()}", fontsize=14)
# Shrink the axes to 80% width so the legend, anchored just outside the right
# edge (bbox_to_anchor x=1), still fits inside the figure.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()