The Dirichlet distribution lives in the simplex, which is like an n-dimensional triangle; a 1-simplex is a line, a 2-simplex is a triangle, a 3-simplex a tetrahedron, and so on. Why a simplex? Intuitively, because the output of this distribution is a K-length vector, whose elements are restricted to be zero or larger and sum up to 1. As we said, the Dirichlet distribution is the generalization of the beta distribution. Thus, a good way to understand the former is to compare it to the latter. We use the beta for two-outcome problems: one outcome with probability p and the other with probability 1-p. As we can see, p + (1-p) = 1. The beta returns a two-element vector, (p, 1-p), but in practice we omit 1-p, as the outcome is entirely determined once we know p. If we want to extend the beta distribution to three outcomes, we need a three-element vector (p1, p2, p3), where each element is larger than zero and p1 + p2 + p3 = 1, and thus p3 = 1 - p1 - p2. We could use three scalars to parameterize such a distribution, and we may call them α, β, and γ; however, we could easily run out of Greek letters, as there are only 24 of them. Instead, we can just use a vector named α with length K, where K is the number of outcomes. Note that we can think of the beta and Dirichlet as distributions over proportions. To get an idea of this distribution, pay attention to Figure 6.4 and try to relate each triangular subplot to a beta distribution with similar parameters:
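To make the simplex constraint concrete, here is a quick NumPy sketch with arbitrary parameter values: draws from a Dirichlet are non-negative vectors that sum to 1, and with K = 2 the Dirichlet reduces to the beta.

```python
import numpy as np

rng = np.random.default_rng(42)

# Draw samples from a 3-outcome Dirichlet; each draw is a point in the 2-simplex
alpha = np.array([2.0, 3.0, 5.0])
samples = rng.dirichlet(alpha, size=1000)

# Every sample is a length-3 vector of non-negative values summing to 1
print(samples.shape)                        # (1000, 3)
print(np.all(samples >= 0))                 # True
print(np.allclose(samples.sum(axis=1), 1))  # True

# With K = 2, the Dirichlet reduces to the beta: the first component of
# Dirichlet([a, b]) draws follows a Beta(a, b) distribution
a, b = 2.0, 5.0
dir2 = rng.dirichlet([a, b], size=100_000)[:, 0]
beta = rng.beta(a, b, size=100_000)
# both sample means should be close to a / (a + b) ≈ 0.29
print(round(dir2.mean(), 2), round(beta.mean(), 2))
```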
Now that we have a better grasp of the Dirichlet distribution, we have all the elements to build mixture models. One way to visualize them is as a K-side coin-flip model on top of a Gaussian estimation model. Using Kruschke-style diagrams:
The rounded-corner box indicates that we have K components, and the categorical variable z decides which of them we use to describe a given data point.
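The generative story encoded by the diagram can be sketched in plain NumPy. The weights, means, and standard deviation below are made-up illustrative values, not estimates from the book's cs_exp data:

```python
import numpy as np

rng = np.random.default_rng(123)

# Hypothetical mixture with K = 2 components (arbitrary parameter values)
K = 2
p = np.array([0.4, 0.6])        # mixture weights, non-negative and summing to 1
means = np.array([47.0, 57.0])  # one mean per component
sd = 3.0                        # shared standard deviation

# For each data point, a K-sided "coin flip" picks a component...
z = rng.choice(K, size=500, p=p)
# ...and the point is then drawn from that component's Gaussian
y = rng.normal(means[z], sd)

print(y.shape)       # (500,)
print(np.unique(z))  # [0 1]
```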
This model (assuming clusters = 2) can be implemented using PyMC3, as follows:
clusters = 2

with pm.Model() as model_kg:
    # mixture weights: a flat Dirichlet prior over the cluster proportions
    p = pm.Dirichlet('p', a=np.ones(clusters))
    # latent cluster assignment, one per data point
    z = pm.Categorical('z', p=p, shape=len(cs_exp))
    means = pm.Normal('means', mu=cs_exp.mean(), sd=10, shape=clusters)
    sd = pm.HalfNormal('sd', sd=10)
    # each data point is modeled by the Gaussian of its assigned cluster
    y = pm.Normal('y', mu=means[z], sd=sd, observed=cs_exp)
    trace_kg = pm.sample()
If you run this code, you will find that it is very slow and that the trace looks very bad (refer to Chapter 8, Inference Engines, to learn more about diagnostics). The reason for such difficulties is that in model_kg we have explicitly included the latent variable z in the model. One problem with this explicit approach is that sampling the discrete variable z usually leads to slow mixing and ineffective exploration of the tails of the distribution. One way to solve these sampling problems is by reparametrizing the model.
Note that in a mixture model, the observed variable y is modeled conditionally on the latent variable z; that is, p(y | z, θ). We may think of the latent variable z as a nuisance variable that we can marginalize out to obtain p(y | θ) = Σ_k p(z = k) p(y | z = k, θ). Luckily for us, PyMC3 includes a NormalMixture distribution that we can use to write a Gaussian mixture model in the following way:
clusters = 2

with pm.Model() as model_mg:
    p = pm.Dirichlet('p', a=np.ones(clusters))
    means = pm.Normal('means', mu=cs_exp.mean(), sd=10, shape=clusters)
    sd = pm.HalfNormal('sd', sd=10)
    y = pm.NormalMixture('y', w=p, mu=means, sd=sd, observed=cs_exp)
    trace_mg = pm.sample(random_seed=123)
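What NormalMixture computes corresponds to the marginalization described above: p(y | θ) = Σ_k p_k N(y | μ_k, σ). Here is a quick numerical sanity check of that identity with arbitrary parameter values, comparing the closed-form weighted sum against a Monte Carlo average over the latent z:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Arbitrary two-component mixture parameters
p = np.array([0.3, 0.7])
means = np.array([-1.0, 3.0])
sd = 1.0
y0 = 0.5  # point at which we evaluate the density

# Marginal density: weighted sum of the component densities
marginal = np.sum(p * stats.norm(means, sd).pdf(y0))

# Monte Carlo over the latent z converges to the same value
z = rng.choice(2, size=200_000, p=p)
mc = stats.norm(means[z], sd).pdf(y0).mean()

# the two estimates agree to about three decimal places
print(round(marginal, 3), round(mc, 3))
```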
Let's use ArviZ to see what the trace looks like; we will compare this trace with the one obtained with model_mgp in the next section:
varnames = ['means', 'p']
az.plot_trace(trace_mg, varnames)
Let's also compute the summary for this model; we will compare it with the one obtained with model_mgp in the next section:
az.summary(trace_mg, varnames)
|          | mean  | sd   | mc error | hpd 3% | hpd 97% | eff_n | r_hat |
|----------|-------|------|----------|--------|---------|-------|-------|
| means[0] | 52.12 | 5.35 | 2.14     | 46.24  | 57.68   | 1.0   | 25.19 |
| means[1] | 52.14 | 5.33 | 2.13     | 46.23  | 57.65   | 1.0   | 24.52 |
| p[0]     | 0.50  | 0.41 | 0.16     | 0.08   | 0.92    | 1.0   | 68.91 |
| p[1]     | 0.50  | 0.41 | 0.16     | 0.08   | 0.92    | 1.0   | 68.91 |
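The near-identical rows for means[0] and means[1], and the weights both centered at 0.5, are symptoms of the non-identifiability of this parameterization: the mixture likelihood is unchanged if we swap the component labels, so different chains can settle into different but equivalent modes. A minimal numerical check of this symmetry, with arbitrary values:

```python
import numpy as np
from scipy import stats

y = np.array([46.0, 50.5, 57.2])  # a few arbitrary data points
p = np.array([0.3, 0.7])
means = np.array([47.0, 57.0])
sd = 2.0

def mixture_logp(y, p, means, sd):
    # log p(y) = sum over data points of log(sum_k p_k N(y | mu_k, sd))
    dens = p * stats.norm(means, sd).pdf(y[:, None])
    return np.log(dens.sum(axis=1)).sum()

# Swapping the component labels leaves the likelihood untouched,
# so the posterior has (at least) two symmetric modes
lp = mixture_logp(y, p, means, sd)
lp_swapped = mixture_logp(y, p[::-1], means[::-1], sd)
print(np.isclose(lp, lp_swapped))  # True
```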