Classification report
Aggregate statistics for multiclass classification can be conveniently calculated and summarized using the classification_report function.
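using ClassificationMetrics
classes = [:Class1, :Class2, :Class3]
# Labels are drawn at random, so the numbers in the report vary from run to run
predicted = rand(classes, 100)
actual = rand(classes, 100)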
classification_report(predicted, actual, label_set=classes)
┌──────────┬───────────┬────────┬──────────┬─────────┐
│ │ Precision │ Recall │ F1 score │ Support │
├──────────┼───────────┼────────┼──────────┼─────────┤
│ Class1 │ 0.4444 │ 0.3529 │ 0.3934 │ 34 │
│ Class2 │ 0.2353 │ 0.2857 │ 0.2581 │ 28 │
│ Class3 │ 0.4359 │ 0.4474 │ 0.4416 │ 38 │
│ Micro │ 0.3719 │ 0.3620 │ 0.3644 │ 100 │
│ Macro │ 0.3700 │ 0.3700 │ 0.3700 │ 100 │
│ Weighted │ 0.3826 │ 0.3700 │ 0.3738 │ 100 │
└──────────┴───────────┴────────┴──────────┴─────────┘
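For reference, here is a minimal, package-independent sketch of how such statistics are commonly defined. It assumes the textbook one-vs-rest conventions: precision = TP/(TP+FP), recall = TP/(TP+FN), support = number of actual instances of a class; "micro" pools the counts over all classes, "macro" takes the unweighted mean of the per-class values, and "weighted" takes the support-weighted mean. The helpers `per_class_counts`, `safe_div`, `f1` and `report_rows` are hypothetical names for illustration and are not part of ClassificationMetrics.

```julia
# Sketch under textbook definitions; NOT the ClassificationMetrics implementation.

# Count true positives, false positives and false negatives for one class.
function per_class_counts(predicted, actual, class)
    tp = count(i -> predicted[i] == class && actual[i] == class, eachindex(actual))
    fp = count(i -> predicted[i] == class && actual[i] != class, eachindex(actual))
    fn = count(i -> predicted[i] != class && actual[i] == class, eachindex(actual))
    return (tp = tp, fp = fp, fn = fn)
end

# Division that returns 0.0 instead of NaN when the denominator is zero.
safe_div(a, b) = b == 0 ? 0.0 : a / b

# F1 score: harmonic mean of precision p and recall r.
f1(p, r) = safe_div(2p * r, p + r)

function report_rows(predicted, actual, classes)
    counts = [per_class_counts(predicted, actual, c) for c in classes]
    prec = [safe_div(c.tp, c.tp + c.fp) for c in counts]
    rec = [safe_div(c.tp, c.tp + c.fn) for c in counts]
    support = [c.tp + c.fn for c in counts]      # number of actual instances

    # Micro: pool TP/FP/FN over all classes, then compute each metric once.
    TP = sum(c.tp for c in counts)
    micro = (safe_div(TP, TP + sum(c.fp for c in counts)),
             safe_div(TP, TP + sum(c.fn for c in counts)))

    # Macro: unweighted mean of the per-class values.
    macro_avg = (sum(prec) / length(classes), sum(rec) / length(classes))

    # Weighted: per-class values weighted by their share of the support.
    w = support ./ sum(support)
    weighted = (sum(w .* prec), sum(w .* rec))

    return (precision = prec, recall = rec, f1 = f1.(prec, rec),
            support = support, micro = micro, macro = macro_avg,
            weighted = weighted)
end

# Toy usage: two classes, four samples, three of them predicted correctly.
toy_report = report_rows([:a, :a, :b, :b], [:a, :b, :b, :b], [:a, :b])
toy_report.micro   # (0.75, 0.75)
```

In single-label data every false positive for one class is a false negative for another, so the pooled (micro) precision and recall coincide and equal the accuracy, as the `(0.75, 0.75)` result of the toy example shows.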
It is possible to specify custom metrics and custom aggregation subsets. For example, to aggregate over only specific classes, use aggregate_subset, which can be passed alongside the other aggregations:
# `precision` is imported explicitly because Base also exports a function named `precision`
using ClassificationMetrics: precision
classification_report(predicted,
    actual,
    label_set=classes,
    metrics=[precision, recall, F1_score, Fβ_score(2)],
    aggregations=[
        aggregate_subset(
            macro_aggregation,
            name="Classes 1 and 2",
            predicate=∈([:Class1, :Class2])),
        aggregate_subset(
            macro_aggregation,
            name="Classes 2 and 3",
            predicate=∈([:Class2, :Class3])),
        micro_aggregation,
        macro_aggregation,
        weighted_aggregation
    ])
┌─────────────────┬───────────┬────────┬──────────┬──────────┬─────────┐
│ │ Precision │ Recall │ F1 score │ F2 score │ Support │
├─────────────────┼───────────┼────────┼──────────┼──────────┼─────────┤
│ Class1 │ 0.4444 │ 0.3529 │ 0.3934 │ 0.3681 │ 34 │
│ Class2 │ 0.2353 │ 0.2857 │ 0.2581 │ 0.2740 │ 28 │
│ Class3 │ 0.4359 │ 0.4474 │ 0.4416 │ 0.4450 │ 38 │
│ Classes 1 and 2 │ 0.3279 │ 0.3226 │ 0.3252 │ 0.3236 │ 62 │
│ Classes 2 and 3 │ 0.3425 │ 0.3788 │ 0.3597 │ 0.3709 │ 66 │
│ Micro │ 0.3719 │ 0.3620 │ 0.3644 │ 0.3624 │ 100 │
│ Macro │ 0.3700 │ 0.3700 │ 0.3700 │ 0.3700 │ 100 │
│ Weighted │ 0.3826 │ 0.3700 │ 0.3738 │ 0.3710 │ 100 │
└─────────────────┴───────────┴────────┴──────────┴──────────┴─────────┘
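The two new ingredients in this report are the Fβ score and the restriction of an aggregation to a subset of classes selected by a predicate such as `∈([:Class1, :Class2])`. The sketch below illustrates both ideas with hypothetical helpers `fbeta` and `subset_macro`, which are not part of the ClassificationMetrics API. It uses the standard formula Fβ = (1 + β²)·P·R / (β²·P + R), and for the subset it shows the unweighted-mean reading; pooling the counts of the selected classes is the other common convention, and the package documentation is the authority on which one aggregate_subset applies.

```julia
# Hypothetical helpers for illustration; not part of the ClassificationMetrics API.

# F-beta score: weighted harmonic mean of precision p and recall r.
# β > 1 gives recall more weight; β = 1 recovers the ordinary F1 score.
function fbeta(p, r, β)
    denom = β^2 * p + r
    return denom == 0 ? 0.0 : (1 + β^2) * p * r / denom
end

# Macro-average a per-class metric over only the classes for which
# `predicate` returns true, e.g. predicate = ∈([:Class1, :Class2]).
function subset_macro(values::AbstractDict, predicate)
    selected = [v for (class, v) in values if predicate(class)]
    return isempty(selected) ? 0.0 : sum(selected) / length(selected)
end

# Class1 row of the report above: precision 0.4444, recall 0.3529.
f2_class1 = fbeta(0.4444, 0.3529, 2)   # ≈ 0.3681, matching the F2 score column

# Toy per-class values, averaged over the subset {:a, :b} only.
toy = Dict(:a => 0.5, :b => 1.0, :c => 0.0)
subset_macro(toy, ∈([:a, :b]))         # 0.75
```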
The table is printed using PrettyTables, and classification_report accepts the keyword arguments defined in that package. For example, to produce the table in LaTeX or Markdown format instead of plain text, specify the backend:
using Markdown
markdown_table = classification_report(predicted,
    actual,
    label_set=classes,
    metrics=[precision, recall, F1_score, Fβ_score(2)],
    aggregations=[
        aggregate_subset(
            macro_aggregation,
            name="Classes 1 and 2",
            predicate=∈([:Class1, :Class2])),
        aggregate_subset(
            macro_aggregation,
            name="Classes 2 and 3",
            predicate=∈([:Class2, :Class3])),
        micro_aggregation,
        macro_aggregation,
        weighted_aggregation
    ],
    io=String,
    backend=Val(:markdown)) |> first
| | Precision | Recall | F1 score | F2 score | Support |
|---|---|---|---|---|---|
| Class1 | 0.4444 | 0.3529 | 0.3934 | 0.3681 | 34 |
| Class2 | 0.2353 | 0.2857 | 0.2581 | 0.2740 | 28 |
| Class3 | 0.4359 | 0.4474 | 0.4416 | 0.4450 | 38 |
| Classes 1 and 2 | 0.3279 | 0.3226 | 0.3252 | 0.3236 | 62 |
| Classes 2 and 3 | 0.3425 | 0.3788 | 0.3597 | 0.3709 | 66 |
| Micro | 0.3719 | 0.3620 | 0.3644 | 0.3624 | 100 |
| Macro | 0.3700 | 0.3700 | 0.3700 | 0.3700 | 100 |
| Weighted | 0.3826 | 0.3700 | 0.3738 | 0.3710 | 100 |
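To see what the Markdown backend's output amounts to, here is a small package-free sketch that formats labelled rows of metric values as a Markdown pipe table and parses it with Julia's Markdown standard library. The helper `markdown_report` is hypothetical and is not how ClassificationMetrics or PrettyTables build the table; it only illustrates the target format.

```julia
using Markdown, Printf

# Hypothetical helper (not part of ClassificationMetrics or PrettyTables):
# format (label, values) rows as a Markdown pipe table with 4 decimal places.
function markdown_report(rows, metric_names)
    header = "| Label | " * join(metric_names, " | ") * " |"
    rule = "|" * repeat("---|", length(metric_names) + 1)   # one column per cell
    body = ["| $(label) | " * join([@sprintf("%.4f", v) for v in vals], " | ") * " |"
            for (label, vals) in rows]
    return join([header, rule, body...], "\n")
end

rows = [("Class1", [0.4444, 0.3529]), ("Class2", [0.2353, 0.2857])]
md_string = markdown_report(rows, ["Precision", "Recall"])
Markdown.parse(md_string)   # displays as a formatted table
```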
latex_table = classification_report(predicted,
    actual,
    label_set=classes,
    metrics=[precision, recall, F1_score, Fβ_score(2)],
    aggregations=[
        aggregate_subset(
            macro_aggregation,
            name="Classes 1 and 2",
            predicate=∈([:Class1, :Class2])),
        aggregate_subset(
            macro_aggregation,
            name="Classes 2 and 3",
            predicate=∈([:Class2, :Class3])),
        micro_aggregation,
        macro_aggregation,
        weighted_aggregation
    ],
    io=String,
    backend=Val(:latex)) |> first