We first convert both the predicted and the true labels from one-hot format back to arrays of integers:
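import numpy as np
# Model outputs for the test images, one row per image
y_test1 = model.predict(X_test)
# Round the outputs to 0/1 and pass them through the label binarizer lb
y_test1 = lb.fit_transform(np.round(y_test1))
# Take the index of the largest entry in each row as the integer label
y_test1 = np.argmax(y_test1, axis=1)
y_test = np.argmax(y_test, axis=1)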
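To see what `np.argmax(..., axis=1)` does in isolation, here is a minimal sketch on hypothetical toy data (three samples and four classes, not the digit data used above). The array names `one_hot_true` and `predicted_probs` are assumptions made up for this illustration.

```python
import numpy as np

# Hypothetical toy data: 3 samples, 4 classes (not the digit data used above).
one_hot_true = np.array([[0, 0, 1, 0],
                         [1, 0, 0, 0],
                         [0, 0, 0, 1]])
predicted_probs = np.array([[0.1, 0.2, 0.6, 0.1],
                            [0.7, 0.1, 0.1, 0.1],
                            [0.2, 0.5, 0.2, 0.1]])

# argmax along axis 1 returns, for each row, the column index of the
# largest value, which is the integer class label.
true_ints = np.argmax(one_hot_true, axis=1)
predicted_ints = np.argmax(predicted_probs, axis=1)

print(true_ints)       # [2 0 3]
print(predicted_ints)  # [2 0 1]
```

In this toy case the third sample would count as mislabeled, because its true label is 3 while the largest predicted score sits in column 1.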
We will extract the indices of mislabeled images, and use them to retrieve the corresponding true and predicted labels:
import numpy as np
mislabeled_indices = np.flatnonzero(y_test != y_test1)
true_labels = y_test[mislabeled_indices]
predicted_labels = y_test1[mislabeled_indices]
print(mislabeled_indices)
print(true_labels)
print(predicted_labels)
The output consists of three NumPy arrays: the indices of the mislabeled images, their true labels, and their predicted labels:
[ 1 8 56 97 117 186 188 192 198 202 230 260 291 294 323 335 337]
[9 7 8 2 4 4 2 4 8 9 6 9 7 6 8 8 1]
[3 9 5 0 9 1 1 9 1 3 0 3 8 8 1 3 2]
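As a side note, the two label arrays can also be paired up to see which confusions occur most often. The following is only a sketch: it copies the values from the printed output into plain lists and tallies the (true, predicted) pairs with `collections.Counter`. The name `confusion_pairs` is an assumption introduced for this illustration.

```python
from collections import Counter

# Values copied from the printed output above.
true_labels = [9, 7, 8, 2, 4, 4, 2, 4, 8, 9, 6, 9, 7, 6, 8, 8, 1]
predicted_labels = [3, 9, 5, 0, 9, 1, 1, 9, 1, 3, 0, 3, 8, 8, 1, 3, 2]

# Count how often each (true digit, predicted digit) pair occurs.
confusion_pairs = Counter(zip(true_labels, predicted_labels))

# Show the three most frequent confusions.
for (true_digit, predicted_digit), count in confusion_pairs.most_common(3):
    print(f"true {true_digit} -> predicted {predicted_digit}: {count}")
```

With these values, the most frequent confusion is a 9 predicted as a 3, which occurs three times.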
Let's count how many samples are mislabeled for each digit and store the counts in a list:
mislabeled_digit_counts = [len(true_labels[true_labels==i]) for i in range(10)]
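The same per-digit counts can be obtained with `np.bincount`, which tallies the occurrences of each non-negative integer in one pass. This is an alternative sketch, not the approach used in this section. It assumes the labels are integers from 0 to 9 and reuses the true labels from the printed output.

```python
import numpy as np

# True labels of the mislabeled images, copied from the printed output above.
true_labels = np.array([9, 7, 8, 2, 4, 4, 2, 4, 8, 9, 6, 9, 7, 6, 8, 8, 1])

# minlength=10 guarantees one entry per digit, even for digits
# that were never mislabeled.
counts = np.bincount(true_labels, minlength=10)

print(counts)  # [0 1 2 0 3 0 2 2 4 3]
```

Position `i` of the result holds the number of mislabeled samples whose true digit is `i`, so the digit 8 has four mislabeled samples here.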
Now, we will plot a bar chart of the ratio of mislabeled samples for each digit:
import pandas as pd
import matplotlib.pyplot as plt
# Calculate the ratio of mislabeled samples for each digit
total_digit_counts = [len(y_test[y_test==i]) for i in range(10)]
mislabeled_ratio = [mislabeled_digit_counts[i]/total_digit_counts[i] for i in range(10)]
# Plot the ratios as a bar chart, one bar per digit
pd.DataFrame(mislabeled_ratio).plot(kind='bar')
plt.xticks(rotation=0)
plt.xlabel('Digit')
plt.ylabel('Mislabeled ratio')
plt.legend([])
plt.show()
This code creates a bar chart showing, for each digit, the ratio of test samples that our model mislabeled:
From the preceding figure, we see that 8 is the digit our model misrecognizes most often. Let's find out why.