Classification report

Aggregate statistics for multiclass classification can be calculated and summarized with the classification_report function.

using ClassificationMetrics

classes = [:Class1, :Class2, :Class3]

predicted = rand(classes, 100)
actual = rand(classes, 100)

classification_report(predicted, actual, label_set=classes)
┌──────────┬───────────┬────────┬──────────┬─────────┐
│          │ Precision │ Recall │ F1 score │ Support │
├──────────┼───────────┼────────┼──────────┼─────────┤
│   Class1 │    0.4444 │ 0.3529 │   0.3934 │      34 │
│   Class2 │    0.2353 │ 0.2857 │   0.2581 │      28 │
│   Class3 │    0.4359 │ 0.4474 │   0.4416 │      38 │
│    Micro │    0.3719 │ 0.3620 │   0.3644 │     100 │
│    Macro │    0.3700 │ 0.3700 │   0.3700 │     100 │
│ Weighted │    0.3826 │ 0.3700 │   0.3738 │     100 │
└──────────┴───────────┴────────┴──────────┴─────────┘
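The per-class rows and the Weighted row can be cross-checked by hand. The sketch below does not use ClassificationMetrics. It assumes per-class counts inferred from the rounded table above (the random inputs themselves are not reproducible, and all variable names are illustrative) and applies the usual definitions: precision is true positives over predicted count, recall is true positives over support, and the weighted average weights each class by its support.

```julia
# Counts inferred from the rounded table above; an assumption, not package output.
true_positives  = [12, 8, 17]    # correctly predicted samples per class
predicted_count = [27, 34, 39]   # how often each class was predicted
support         = [34, 28, 38]   # how often each class actually occurs

precision_per_class = true_positives ./ predicted_count
recall_per_class    = true_positives ./ support
f1_per_class = 2 .* precision_per_class .* recall_per_class ./
               (precision_per_class .+ recall_per_class)

# Support-weighted averages, corresponding to the "Weighted" row
weighted_precision = sum(support .* precision_per_class) / sum(support)
weighted_recall    = sum(support .* recall_per_class) / sum(support)
weighted_f1        = sum(support .* f1_per_class) / sum(support)

println(round.(precision_per_class, digits=4))  # [0.4444, 0.2353, 0.4359]
println(round.(recall_per_class, digits=4))     # [0.3529, 0.2857, 0.4474]
println(round(weighted_precision, digits=4))    # 0.3826
```

The support-weighted recall reduces to the overall fraction of correct predictions, here 37 out of 100.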

Custom metrics and custom aggregation subsets can also be specified. For example, to aggregate over specific classes only, wrap an aggregation in aggregate_subset and pass it alongside the other aggregations:

# `precision` also exists in Base, so bring the package's version into scope explicitly
using ClassificationMetrics: precision

classification_report(predicted,
                        actual,
                        label_set=classes,
                        metrics=[precision, recall, F1_score, Fβ_score(2)],
                        aggregations=[
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 1 and 2",
                                predicate=∈([:Class1, :Class2])),
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 2 and 3",
                                predicate=∈([:Class2, :Class3])),
                            micro_aggregation,
                            macro_aggregation,
                            weighted_aggregation
                        ])
┌─────────────────┬───────────┬────────┬──────────┬──────────┬─────────┐
│                 │ Precision │ Recall │ F1 score │ F2 score │ Support │
├─────────────────┼───────────┼────────┼──────────┼──────────┼─────────┤
│          Class1 │    0.4444 │ 0.3529 │   0.3934 │   0.3681 │      34 │
│          Class2 │    0.2353 │ 0.2857 │   0.2581 │   0.2740 │      28 │
│          Class3 │    0.4359 │ 0.4474 │   0.4416 │   0.4450 │      38 │
│ Classes 1 and 2 │    0.3279 │ 0.3226 │   0.3252 │   0.3236 │      62 │
│ Classes 2 and 3 │    0.3425 │ 0.3788 │   0.3597 │   0.3709 │      66 │
│           Micro │    0.3719 │ 0.3620 │   0.3644 │   0.3624 │     100 │
│           Macro │    0.3700 │ 0.3700 │   0.3700 │   0.3700 │     100 │
│        Weighted │    0.3826 │ 0.3700 │   0.3738 │   0.3710 │     100 │
└─────────────────┴───────────┴────────┴──────────┴──────────┴─────────┘
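The F2 score column can be reproduced from the precision and recall columns with the standard Fβ formula, Fβ = (1 + β²)·P·R / (β²·P + R). The sketch below is a standalone helper, not the package's Fβ_score; the name fbeta and its keyword argument are hypothetical, and the inputs are the rounded values from the table above.

```julia
# Standard Fβ formula; β > 1 weights recall more heavily than precision.
# `fbeta` is an illustrative helper, not part of ClassificationMetrics.
fbeta(p, r; β=1) = (1 + β^2) * p * r / (β^2 * p + r)

# Rounded precision and recall for Class1, taken from the table above
p, r = 0.4444, 0.3529

println(round(fbeta(p, r), digits=4))        # F1 score: 0.3934
println(round(fbeta(p, r; β=2), digits=4))   # F2 score: 0.3681
```

With β = 1 the formula reduces to the harmonic mean of precision and recall, which is why the F1 score column is a special case of the same metric.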

The table is printed with PrettyTables, so classification_report accepts the keyword arguments defined by that package. For example, to produce the table in $\LaTeX$ or Markdown format, specify the backend:

using Markdown

markdown_table = classification_report(predicted,
                        actual,
                        label_set=classes,
                        metrics=[precision, recall, F1_score, Fβ_score(2)],
                        aggregations=[
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 1 and 2",
                                predicate=∈([:Class1, :Class2])),
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 2 and 3",
                                predicate=∈([:Class2, :Class3])),
                            micro_aggregation,
                            macro_aggregation,
                            weighted_aggregation
                        ],
                        io=String,
                        backend=Val(:markdown)) |> first
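
|                 | Precision | Recall | F1 score | F2 score | Support |
|----------------:|----------:|-------:|---------:|---------:|--------:|
|          Class1 |    0.4444 | 0.3529 |   0.3934 |   0.3681 |      34 |
|          Class2 |    0.2353 | 0.2857 |   0.2581 |   0.2740 |      28 |
|          Class3 |    0.4359 | 0.4474 |   0.4416 |   0.4450 |      38 |
| Classes 1 and 2 |    0.3279 | 0.3226 |   0.3252 |   0.3236 |      62 |
| Classes 2 and 3 |    0.3425 | 0.3788 |   0.3597 |   0.3709 |      66 |
|           Micro |    0.3719 | 0.3620 |   0.3644 |   0.3624 |     100 |
|           Macro |    0.3700 | 0.3700 |   0.3700 |   0.3700 |     100 |
|        Weighted |    0.3826 | 0.3700 |   0.3738 |   0.3710 |     100 |
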
latex_table = classification_report(predicted,
                        actual,
                        label_set=classes,
                        metrics=[precision, recall, F1_score, Fβ_score(2)],
                        aggregations=[
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 1 and 2",
                                predicate=∈([:Class1, :Class2])),
                            aggregate_subset(
                                macro_aggregation,
                                name="Classes 2 and 3",
                                predicate=∈([:Class2, :Class3])),
                            micro_aggregation,
                            macro_aggregation,
                            weighted_aggregation
                        ],
                        io=String,
                        backend=Val(:latex)) |> first