One of the main concerns with finite mixture models is how to decide the number of components. A rule of thumb is to begin with a relatively small number of components and then increase it in order to improve the model-fit evaluation. As usual, model-fit is evaluated using posterior-predictive checks, measures such as WAIC or LOO, and on the basis of the expertise of the modeler(s).
Let us compare the model for K = {3, 4, 5, 6}. In order to do this, we are going to fit the model four times, and we are going to save the trace and model objects for later use:
clusters = [3, 4, 5, 6]
models = []
traces = []
for cluster in clusters:
    with pm.Model() as model:
        p = pm.Dirichlet('p', a=np.ones(cluster))
        means = pm.Normal('means',
                          mu=np.linspace(cs_exp.min(), cs_exp.max(),
                                         cluster),
                          sd=10, shape=cluster,
                          transform=pm.distributions.transforms.ordered)
        sd = pm.HalfNormal('sd', sd=10)
        y = pm.NormalMixture('y', w=p, mu=means, sd=sd, observed=cs_exp)
        trace = pm.sample(1000, tune=2000, random_seed=123)
        traces.append(trace)
        models.append(model)
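The NormalMixture likelihood used above corresponds to a weighted sum of Gaussian densities, p(y) = Σ_k w_k N(y | μ_k, σ), with a shared standard deviation. As a minimal, self-contained sketch of what this density looks like (using made-up weights and means for illustration, not posterior estimates), with only NumPy/SciPy:

```python
import numpy as np
from scipy import stats

# Hypothetical mixture parameters (illustration only, not posterior values)
w = np.array([0.3, 0.5, 0.2])         # component weights, must sum to 1
means = np.array([50.0, 57.0, 64.0])  # component means (ordered, as in the model)
sd = 3.0                              # shared standard deviation

def mixture_pdf(y, w, means, sd):
    """Density of a Gaussian mixture with a shared standard deviation."""
    # shape (len(y), n_components): each column is one weighted component
    comps = w * stats.norm(means, sd).pdf(np.asarray(y)[:, None])
    return comps.sum(axis=1)

y = np.linspace(40, 75, 500)
pdf = mixture_pdf(y, w, means, sd)
# The mixture is a proper density: it is non-negative and integrates to ~1
print(pdf.sum() * (y[1] - y[0]))
```

This is exactly the quantity the plotting code below computes sample by sample, via `np.sum(dist.pdf(x_) * p_y, 1)`.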
To better display how K affects the inference, we are going to compare the fit of these models with the KDE of the data obtained with az.plot_kde. We are also going to plot the Gaussian components of the mixture model:
_, ax = plt.subplots(2, 2, figsize=(11, 8), constrained_layout=True)
ax = np.ravel(ax)
x = np.linspace(cs_exp.min(), cs_exp.max(), 200)
for idx, trace_x in enumerate(traces):
x_ = np.array([x] * clusters[idx]).T
for i in range(50):
i_ = np.random.randint(0, len(trace_x))
means_y = trace_x['means'][i_]
p_y = trace_x['p'][i_]
sd = trace_x['sd'][i_]
dist = stats.norm(means_y, sd)
ax[idx].plot(x, np.sum(dist.pdf(x_) * p_y, 1), 'C0', alpha=0.1)
means_y = trace_x['means'].mean(0)
p_y = trace_x['p'].mean(0)
sd = trace_x['sd'].mean()
dist = stats.norm(means_y, sd)
ax[idx].plot(x, np.sum(dist.pdf(x_) * p_y, 1), 'C0', lw=2)
ax[idx].plot(x, dist.pdf(x_) * p_y, 'k--', alpha=0.7)
az.plot_kde(cs_exp, plot_kwargs={'linewidth':2, 'color':'k'}, ax=ax[idx])
ax[idx].set_title('K = {}'.format(clusters[idx]))
ax[idx].set_yticks([])
ax[idx].set_xlabel('x')
Figure 6.8 shows a KDE of the data (solid black line), together with the mean fit (wider blue line) and samples from the posterior (semitransparent blue lines). The mean Gaussian components are represented with dashed black lines. From Figure 6.8, it seems that K = 3 is too low, and 4, 5, or 6 could be a better choice.
Notice that the Gaussian mixture model shows two central peaks/bumps (more or less around 55-60), while the KDE predicts less marked (more flattened) peaks. Notice that this is not necessarily a bad fit of the Gaussian mixture model, since KDEs are generally tuned to provide smoother densities. Instead of a KDE, you could use a histogram, but histograms are also just methods to approximate densities. As we have already discussed in Chapter 5, Model Comparison, you could try to compute posterior predictive plots of test quantities of interest and compute Bayesian p-values. Figure 6.9 shows an example of such a calculation and visualization:
ppc_mm = [pm.sample_posterior_predictive(traces[i], 1000, models[i])
          for i in range(4)]

fig, ax = plt.subplots(2, 2, figsize=(10, 6), sharex=True,
                       constrained_layout=True)
ax = np.ravel(ax)

def iqr(x, a=0):
    return np.subtract(*np.percentile(x, [75, 25], axis=a))

T_obs = iqr(cs_exp)
for idx, d_sim in enumerate(ppc_mm):
    T_sim = iqr(d_sim['y'][:100].T, 1)
    p_value = np.mean(T_sim >= T_obs)
    az.plot_kde(T_sim, ax=ax[idx])
    ax[idx].axvline(T_obs, 0, 1, color='k', ls='--')
    ax[idx].set_title(f'K = {clusters[idx]} p-value {p_value:.2f}')
    ax[idx].set_yticks([])
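To build intuition for why a Bayesian p-value near 0.5 signals a good fit, here is a self-contained sketch (using synthetic standard-normal data, not cs_exp) that simulates replicated datasets from the same distribution as the "observed" data and computes the p-value for the IQR statistic, with the same iqr helper as above. Because the replicates match the data, T_sim straddles T_obs and the p-value is not extreme:

```python
import numpy as np

rng = np.random.default_rng(42)

def iqr(x, a=0):
    """Interquartile range along axis a (same helper as in the main code)."""
    return np.subtract(*np.percentile(x, [75, 25], axis=a))

# "Observed" data and replicated datasets drawn from the same distribution
obs = rng.normal(size=1000)
sims = rng.normal(size=(500, 1000))  # 500 replicated datasets of equal size

T_obs = iqr(obs)
T_sim = iqr(sims, a=1)
p_value = np.mean(T_sim >= T_obs)
# T_sim is centered on the true IQR of a standard normal (~1.35), so
# T_obs falls well inside its distribution and p_value is far from 0 or 1
print(round(p_value, 2))
```

A badly misspecified model would instead push T_obs into a tail of T_sim, yielding a p-value near 0 or 1.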
From Figure 6.9, we can see that K = 6 is a good choice, with a Bayesian p-value very close to 0.5. As we can see in the following DataFrame and in Figure 6.10, WAIC also spots K = 6 as the better model (among the evaluated ones):
comp = az.compare(dict(zip(clusters, traces)), method='BB-pseudo-BMA')
comp
|   | waic    | pwaic   | dwaic   | weight      | se      | dse     | warning |
|---|---------|---------|---------|-------------|---------|---------|---------|
| 6 | 10250   | 12.368  | 0       | 0.948361    | 62.7354 | 0       | 0       |
| 5 | 10259.7 | 10.3531 | 9.69981 | 0.0472388   | 61.3804 | 4.6348  | 0       |
| 4 | 10278.9 | 7.45718 | 28.938  | 0.00440011  | 60.7985 | 9.82746 | 0       |
| 3 | 10356.9 | 5.90559 | 106.926 | 3.19235e-13 | 60.9242 | 18.5501 | 0       |
More often than not, reading a plot is easier than reading a table, so let's make a plot to see how different the models are according to WAIC. As we can see from Figure 6.10, the model with six components has a lower WAIC than the rest, but there is considerable overlap when we take into account the estimated standard error (se), especially with regard to the model with five components:
az.plot_compare(comp)
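The weight column in the table comes from Bayesian-bootstrapped pseudo-BMA (the method='BB-pseudo-BMA' argument). The plain, non-bootstrapped pseudo-BMA weights are easy to compute by hand from the WAIC differences, weight_i ∝ exp(-ΔWAIC_i / 2). A minimal sketch using the deviance-scale WAIC values from the table (the BB-pseudo-BMA weights reported by ArviZ will differ somewhat because of the bootstrap):

```python
import numpy as np

# WAIC values (deviance scale, lower is better) for K = 6, 5, 4, 3,
# taken from the comparison table above
waic = np.array([10250.0, 10259.7, 10278.9, 10356.9])

# Pseudo-BMA (Akaike-type) weights: weight_i ∝ exp(-0.5 * (WAIC_i - min WAIC))
d_waic = waic - waic.min()
weights = np.exp(-0.5 * d_waic)
weights /= weights.sum()
print(weights.round(4))
```

As with the table, essentially all the weight lands on the six-component model, with the five-component model a distant second; the bootstrap in BB-pseudo-BMA spreads the weights out a little by accounting for the uncertainty in the WAIC estimates.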