Confusion matrix

Demystifying the Confusion Matrix


๐‡๐จ๐ฐ ๐ญ๐จ ๐’๐ฉ๐จ๐ญ ๐„๐ซ๐ซ๐จ๐ซ๐ฌ ๐ข๐ง ๐‚๐ฅ๐š๐ฌ๐ฌ๐ข๐Ÿ๐ข๐œ๐š๐ญ๐ข๐จ๐ง ๐Œ๐จ๐๐ž๐ฅ๐ฌ ๐Ÿ”

Understanding where your model gets it wrong is crucial to improving its performance. Let's see how a confusion matrix helps with that.

What is Misclassification?

๐‘€๐‘–๐‘ ๐‘๐‘™๐‘Ž๐‘ ๐‘ ๐‘–๐‘“๐‘–๐‘๐‘Ž๐‘ก๐‘–๐‘œ๐‘› โ„Ž๐‘Ž๐‘๐‘๐‘’๐‘›๐‘  ๐‘คโ„Ž๐‘’๐‘› ๐‘Ž ๐‘š๐‘œ๐‘‘๐‘’๐‘™ ๐‘๐‘Ÿ๐‘’๐‘‘๐‘–๐‘๐‘ก๐‘  ๐‘Ž๐‘› ๐‘–๐‘›๐‘๐‘œ๐‘Ÿ๐‘Ÿ๐‘’๐‘๐‘ก ๐‘๐‘™๐‘Ž๐‘ ๐‘  ๐‘™๐‘Ž๐‘๐‘’๐‘™ ๐‘“๐‘œ๐‘Ÿ ๐‘Ž ๐‘”๐‘–๐‘ฃ๐‘’๐‘› ๐‘–๐‘›๐‘๐‘ข๐‘ก. For example, classifying a cat ๐Ÿฑ as a dog ๐Ÿถ.


What's a Confusion Matrix?

It is a table that describes the performance of a classification model by comparing actual labels with predicted labels. For a binary problem, it breaks predictions down into:

  • True Positives (TP): Positive samples correctly predicted as positive.
  • True Negatives (TN): Negative samples correctly predicted as negative.
  • False Positives (FP): Negative samples incorrectly predicted as positive.
  • False Negatives (FN): Positive samples incorrectly predicted as negative.
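
import matplotlib.pyplot as plt
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix

# 1. Assumes earlier steps produced:
#    y_pred_tensor - predicted class indices for the test set
#    test_data     - the test dataset (e.g. a torchvision dataset with .targets)
#    class_names   - list of class names, in label-index order

# 2. Set up the confusion matrix metric
confmat = ConfusionMatrix(num_classes=len(class_names),
                          task='multiclass')

# Rows = true labels, columns = predicted labels
# (.cpu() keeps predictions on the same device as the CPU targets)
confmat_tensor = confmat(preds=y_pred_tensor.cpu(),
                         target=test_data.targets)

# 3. Plot the confusion matrix
fig, ax = plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),
    class_names=class_names,
    figsize=(10, 7)
)
plt.show()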
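Under the hood, the matrix is just a table of counts. The plain-Python sketch below shows what the torchmetrics call is expected to compute. It assumes the convention that rows are true labels and columns are predicted labels. confusion_matrix here is a hypothetical helper, not part of the torchmetrics API.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes list of lists.

    matrix[i][j] counts samples whose actual class is i and predicted class is j,
    so rows are actual labels and columns are predicted labels.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(y_true, y_pred):
        matrix[actual][predicted] += 1
    return matrix


# Toy 3-class example, invented for illustration
target = [2, 1, 0, 0]
preds = [2, 1, 0, 1]
for row in confusion_matrix(target, preds, num_classes=3):
    print(row)
# [1, 1, 0]
# [0, 1, 0]
# [0, 0, 1]
```

With the tensors from the snippet above, confusion_matrix(test_data.targets.tolist(), y_pred_tensor.tolist(), len(class_names)) should match confmat_tensor.tolist().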

๐‡๐จ๐ฐ ๐‚๐š๐ง ๐š ๐‚๐จ๐ง๐Ÿ๐ฎ๐ฌ๐ข๐จ๐ง ๐Œ๐š๐ญ๐ซ๐ข๐ฑ ๐‡๐ž๐ฅ๐ฉ ๐ˆ๐๐ž๐ง๐ญ๐ข๐Ÿ๐ฒ ๐„๐ซ๐ซ๐จ๐ซ ๐‘๐ž๐ ๐ข๐จ๐ง๐ฌ?

The confusion matrix shows which kind of error each class suffers from:
  • False Positives (FP): A high FP count for a class means the model over-predicts it, labeling samples from other classes as that class.
  • False Negatives (FN): A high FN count means the model under-predicts the class, missing samples that truly belong to it.

By studying the matrix, you can spot patterns such as:
  • Which classes are most often confused with each other.
  • Whether certain classes are consistently misclassified.
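For a multiclass matrix, per-class FP and FN counts come from the column and row totals. The sketch below assumes rows are actual labels and columns are predicted labels. The helper names, class names and 3-class counts are all invented for illustration.

```python
def per_class_errors(matrix):
    """FP for class k = column k total minus diagonal; FN = row k total minus diagonal."""
    n = len(matrix)
    errors = []
    for k in range(n):
        fp = sum(matrix[i][k] for i in range(n)) - matrix[k][k]  # others predicted as k
        fn = sum(matrix[k]) - matrix[k][k]  # class k predicted as something else
        errors.append({"FP": fp, "FN": fn})
    return errors


def most_confused_pair(matrix):
    """Return (actual, predicted, count) for the largest off-diagonal cell."""
    n = len(matrix)
    return max(
        ((i, j, matrix[i][j]) for i in range(n) for j in range(n) if i != j),
        key=lambda cell: cell[2],
    )


# Toy matrix (rows = actual, columns = predicted), invented for illustration
toy_names = ["cat", "dog", "fox"]
matrix = [
    [50, 8, 2],   # actual cat
    [3, 45, 2],   # actual dog
    [1, 9, 30],   # actual fox
]

for name, err in zip(toy_names, per_class_errors(matrix)):
    print(name, err)
actual, predicted, count = most_confused_pair(matrix)
print(f"Most confused: {toy_names[actual]} predicted as {toy_names[predicted]} ({count} times)")
```

In this toy matrix the dog column collects the most FPs (17), so the model over-predicts dog. The most frequent confusion is fox predicted as dog, which happens 9 times.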

๐‡๐จ๐ฐ ๐ญ๐จ ๐‘๐ž๐š๐ ๐š ๐‚๐จ๐ง๐Ÿ๐ฎ๐ฌ๐ข๐จ๐ง ๐Œ๐š๐ญ๐ซ๐ข๐ฑ?

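  • Diagonal values: Represent correct predictions (the higher, the better!).
  • Off-diagonal values: Represent incorrect predictions, i.e. misclassifications.
  • Top-right cell → False Positives (FP).
  • Bottom-left cell → False Negatives (FN).

The last two points assume a binary matrix with actual labels as rows, predicted labels as columns, and the negative class first (as in the torchmetrics/mlxtend output above when 0 is the negative label). If the positive class is listed first, the FP and FN cells swap.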
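The binary layout can be read in code as shown below. The sketch assumes rows are actual labels, columns are predicted labels, and the negative class comes first. unpack_binary and accuracy are hypothetical helpers, and the counts are invented.

```python
def unpack_binary(matrix):
    """Unpack a 2x2 matrix laid out as rows = actual, columns = predicted, negative first."""
    (tn, fp), (fn, tp) = matrix  # top-right is FP, bottom-left is FN in this layout
    return tn, fp, fn, tp


def accuracy(matrix):
    """Share of samples on the diagonal, i.e. correct predictions."""
    correct = sum(matrix[k][k] for k in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


# Invented counts: 100 samples in total
matrix = [
    [40, 5],   # actual negative: TN, FP
    [10, 45],  # actual positive: FN, TP
]
tn, fp, fn, tp = unpack_binary(matrix)
print(f"TN={tn} FP={fp} FN={fn} TP={tp}")  # TN=40 FP=5 FN=10 TP=45
print(f"accuracy={accuracy(matrix):.2f}")  # accuracy=0.85
```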

For example, a high FN count signals that the model is missing positive cases. This is critical in fraud detection or medical diagnosis, where a missed case is often far more costly than a false alarm.

Tips 💡

  • ๐”๐ฌ๐ž ๐ฐ๐ข๐ญ๐ก ๐ข๐ฆ๐›๐š๐ฅ๐š๐ง๐œ๐ž๐ ๐๐š๐ญ๐š๐ฌ๐ž๐ญ๐ฌ: For datasets where some classes have more samples than others, better to use confusion matrix.
  • ๐๐ซ๐ž๐œ๐ข๐ฌ๐ข๐จ๐ง ๐š๐ง๐ ๐‘๐ž๐œ๐š๐ฅ๐ฅ: Combine the confusion matrix with Precision and Recall to understand modelโ€™s performance.
  • ๐“๐ฎ๐ง๐ž ๐ญ๐ก๐ซ๐ž๐ฌ๐ก๐จ๐ฅ๐๐ฌ: If FPs or FNs are too high, consider adjusting your classification thresholds.
  • ๐‹๐จ๐จ๐ค ๐Ÿ๐จ๐ซ ๐ฉ๐š๐ญ๐ญ๐ž๐ซ๐ง๐ฌ: Are certain classes consistently misclassified? Reviewing your data or refining features for those classes.

#MachineLearning #AI #DataScience #DeepLearning #PyTorch #ConfusionMatrix #Classification #ModelTraining