Recall Benchmark

This notebook measures the recall of \(k\)-MSTs computed with NN-Descent against ground-truth \(k\)-MSTs. Several parameters are varied to investigate how the algorithm behaves. In particular, we vary the dataset, the number of neighbors, and the number of descent neighbors; the latter indicates how many neighbors are used in the NN-Descent stage. When it exceeds the number of neighbors required for a \(k\)-NN network to form a single connected component, normal NN-Descent should find all MST edges, and the performance of the MST-descent stage is not measured well. Throughout the parameter sweep, we therefore record the number of neighbors each dataset requires for its \(k\)-NN network to form a single connected component. In addition, we measure the recall and distance fraction for the global output, each Boruvka iteration, and each descent iteration. The global distance fraction is computed over all edges; the Boruvka and descent distance fractions only consider the ground-truth edges of each Boruvka iteration.
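
For reference, a minimal sketch of these two measures, assuming `true_graph` and `approx_graph` are weighted `scipy.sparse` adjacency matrices of the ground-truth and approximate \(k\)-MSTs (the names are illustrative, not part of the benchmark code):

from scipy.sparse import csr_matrix

def graph_measures(true_graph: csr_matrix, approx_graph: csr_matrix):
    """Illustrative helper, not part of the benchmark code."""
    # Binarize copies so that multiply() acts as an edge-set intersection.
    m1 = true_graph.copy()
    m1.data[:] = 1
    m2 = approx_graph.copy()
    m2.data[:] = 1
    recall = m1.multiply(m2).nnz / m1.nnz  # fraction of true edges recovered
    dist_frac = approx_graph.data.sum() / true_graph.data.sum()  # total-distance ratio
    return recall, dist_frac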

The main questions we want to answer are:

  • How accurate is our NN-Descent for constructing \(k\)-MSTs?

  • How accurate is our NN-Descent in finding shortest edges between connected components?

  • How do the parameters influence NN-Descent convergence?

[2]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.sparse.csgraph import connected_components
from sklearn.datasets import load_diabetes, load_iris, load_digits, load_wine, fetch_openml
from sklearn.preprocessing import RobustScaler

from umap import UMAP
from multi_mst import KMST, KMSTDescentLogRecall
from lib.drawing import draw_umap

Datasets

The cells in this section load and pre-process the datasets (where necessary).

[3]:
data = {}

SKLearn Diabetes

[4]:
X, y = load_diabetes(return_X_y=True)
p = UMAP(n_neighbors=5).fit(X)
draw_umap(p)
data['diabetes'] = X
_images/Benchmark_recall_5_1.png

SKLearn Iris

[5]:
X, y = load_iris(return_X_y=True)
p = UMAP(n_neighbors=5).fit(X)
draw_umap(p)
data['iris'] = X
_images/Benchmark_recall_7_1.png

SKLearn Digits

[6]:
X, y = load_digits(return_X_y=True)
p = UMAP(n_neighbors=5).fit(X)
draw_umap(p)
data['digits'] = X
_images/Benchmark_recall_9_1.png

SKLearn Wine

[7]:
X, y = load_wine(return_X_y=True)
X = RobustScaler().fit_transform(X)
p = UMAP(n_neighbors=5).fit(X)
draw_umap(p)
data['wine'] = X
_images/Benchmark_recall_11_1.png

Horse

[8]:
X = pd.read_csv('data/horse/horse.csv').to_numpy()
p = UMAP(n_neighbors=20).fit(X)
draw_umap(p)
data['horse'] = X
_images/Benchmark_recall_13_1.png

MNIST

[9]:
X, target = fetch_openml("mnist_784", version=1, return_X_y=True)
p = UMAP().fit(X)
draw_umap(p)
data['mnist'] = X
_images/Benchmark_recall_15_1.png

Parameter Sweep

The cell below evaluates every parameter combination multiple times on each dataset and collects the measurements in a single dataframe.

[13]:
repeats = 5
neighbours = [2, 3, 4, 6]
min_descent_neighbours = [6, 12, 24, 36]
results = []
for dataset_name in tqdm(data.keys()):
    X = data[dataset_name]

    connected_k = 2
    while True:
        p = UMAP(n_neighbors=connected_k, transform_mode="graph").fit(X)
        if connected_components(p.graph_, directed=False, return_labels=False) == 1:
            break
        connected_k += 1
    data_size = X.shape[0]
    data_dims = X.shape[1]

    for k in neighbours:
        # Compute ground truth
        kmst = KMST(num_neighbors=k).fit(X).umap(transform_mode="graph")
        m1 = kmst.graph_.copy()
        m1.data[:] = 1

        for repeat in range(repeats):
            for n in min_descent_neighbours:
                if n is None:
                    n = k  # None is a sentinel meaning "use k itself"
                elif n < k:
                    continue  # skip: descent neighbors must be at least k
                # Compute Descent kMST
                dmst = KMSTDescentLogRecall(
                    num_neighbors=k,
                    min_descent_neighbors=n,
                ).fit(X)
                m2 = dmst.umap(transform_mode="graph").graph_.copy()
                m2.data[:] = 1

                # Extract trace measures
                true_positive = m1.multiply(m2).nnz
                if len(dmst.trace_) == 0:
                    measures = pd.DataFrame(
                        {
                            "dataset": [dataset_name],
                            "boruvka_num_components": [[]],
                            "descent_distance_fraction": [[]],
                            "boruvka_recall": [[]],
                            "boruvka_distance_fraction": [[]],
                            "descent_num_changes": [[]],
                            "descent_recall": [[]],
                        }
                    )
                else:
                    measures = pd.DataFrame(dmst.trace_)
                    measures["boruvka_distance_fraction"] = measures[
                        "descent_distance_fraction"
                    ].apply(lambda x: x[np.argmax(np.isnan(x)) - 1])
                    # Convert to one row with lists
                    measures["dataset"] = dataset_name
                    measures = measures.groupby("dataset").agg(list).reset_index()
                # Add per-run measures
                measures["global_recall"] = true_positive / m1.nnz
                measures["global_precision"] = true_positive / m2.nnz
                measures["global_dist_frac"] = (
                    dmst.graph_.data.sum() / kmst.graph_.data.sum()
                )
                measures["connected_k"] = connected_k
                measures["num_observations"] = data_size
                measures["num_dimensions"] = data_dims
                measures["min_descent_neighbors"] = n
                measures["num_neighbors"] = k
                measures["repeat"] = repeat
                results.append(measures)

results = pd.concat(results, ignore_index=True)
results.head()
100%|██████████| 6/6 [13:35:13<00:00, 8152.23s/it]
[13]:
dataset boruvka_num_components boruvka_recall descent_recall descent_distance_fraction descent_num_changes boruvka_distance_fraction global_recall global_precision global_dist_frac connected_k num_observations num_dimensions min_descent_neighbors num_neighbors repeat
0 diabetes [93, 17, 3] [1.0, 1.0, 1.0] [[0.9111111, 0.94027776, 0.94861114, 0.95, 0.9... [[1.0, 1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan,... [[3072.0, 314.0, 54.0, 8.0, 2.0, nan, nan, nan... [1.0, 1.0, 1.0] 0.986513 0.986513 0.043145 3 442 10 6 2 0
1 diabetes [93, 17, 3] [1.0, 1.0, 1.0] [[0.96067417, 0.96348315, 0.96348315, 0.963483... [[1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan, nan,... [[4965.0, 87.0, 6.0, 1.0, nan, nan, nan, nan, ... [1.0, 1.0, 1.0] 1.000000 1.000000 0.043106 3 442 10 12 2 0
2 diabetes [93, 17, 3] [1.0, 1.0, 1.0] [[0.95694447, 0.95694447, nan, nan, nan, nan, ... [[1.0, 1.0, nan, nan, nan, nan, nan, nan, nan,... [[6886.0, 2.0, nan, nan, nan, nan, nan, nan, n... [1.0, 1.0, 1.0] 1.000000 1.000000 0.043106 3 442 10 24 2 0
3 diabetes [93, 17, 3] [1.0, 1.0, 1.0] [[0.95111734, 0.95111734, nan, nan, nan, nan, ... [[1.0, 1.0, nan, nan, nan, nan, nan, nan, nan,... [[7852.0, 0.0, nan, nan, nan, nan, nan, nan, n... [1.0, 1.0, 1.0] 1.000000 1.000000 0.043106 3 442 10 36 2 0
4 diabetes [93, 17, 3] [1.0, 1.0, 1.0] [[0.9044321, 0.9362881, 0.94598335, 0.94598335... [[1.0006415, 1.0006415, 1.0, 1.0, 1.0, 1.0, na... [[3038.0, 352.0, 51.0, 13.0, 3.0, 1.0, nan, na... [1.0, 1.0, 1.0] 0.992293 0.992293 0.043115 3 442 10 6 2 1
[14]:
results.to_parquet("./data/generated/recall_benchmark.parquet")

Results

This section creates plots showing our results for each of our questions:

  • How accurate is our NN-Descent for constructing \(k\)-MSTs?

  • How accurate is our NN-Descent in finding shortest edges between connected components?

  • How do the parameters influence NN-Descent’s performance?

[16]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from lib.plotting import *
configure_matplotlib()

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
[17]:
results = pd.read_parquet("./data/generated/recall_benchmark.parquet")
min_descent_neighbours = results.min_descent_neighbors.unique()
[18]:
results.groupby('dataset').agg({'num_observations': 'first', 'num_dimensions': 'first', 'connected_k': 'first'}).reset_index()
[18]:
dataset num_observations num_dimensions connected_k
0 diabetes 442 10 3
1 digits 1797 64 8
2 horse 8431 3 18
3 iris 150 4 26
4 mnist 70000 784 5
5 wine 178 13 4
[19]:
results["num_boruvka_iters"] = results.boruvka_recall.apply(len)
results["boruvka_iteration"] = results.num_boruvka_iters.apply(lambda x: list(range(x)))

vars = [
    "dataset",
    "repeat",
    "connected_k",
    "num_neighbors",
    "num_observations",
    "min_descent_neighbors",
]
measures = [
    "boruvka_iteration",
    "boruvka_num_components",
    "boruvka_recall",
    "boruvka_distance_fraction",
    "descent_num_changes",
    "descent_recall",
    "descent_distance_fraction",
]
exploded = results[vars + measures].explode(measures).reset_index(drop=True)
[20]:
exploded['num_descent_iters'] = exploded.descent_recall.apply(lambda x: len(x) if hasattr(x, '__len__') else 0)
exploded['descent_converged'] = exploded.num_descent_iters != exploded.descent_recall.apply(lambda x: np.argmax(np.isnan(x)))
exploded['descent_iteration'] = exploded.num_descent_iters.apply(lambda x: list(range(x)))

vars = ['dataset', 'repeat', 'connected_k', 'num_neighbors', 'min_descent_neighbors', 'boruvka_num_components', 'boruvka_iteration', 'descent_converged']
measures = ['descent_iteration', 'descent_num_changes', 'descent_recall', 'descent_distance_fraction']
twice_exploded = exploded[vars + measures].explode(measures).reset_index(drop=True)
[21]:
exploded['descent_final_recall'] = exploded.descent_recall.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)
exploded['descent_final_dist_frac'] = exploded.descent_distance_fraction.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)
exploded['descent_final_changes'] = exploded.descent_num_changes.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)

How accurate is our NN-Descent for constructing \(k\)-MSTs?

The global recall was high (>0.94) in all cases, which is a promising result. There are two concerns:

  1. Most edges in the \(k\)-MST come from the \(k\)-NN graph, so the global recall mostly reflects how well the \(k\)-NN edges are found rather than the MST edges (the sketch after this list checks that overlap directly).

  2. The global recall was worse at low \(k\) for a difficult dataset (one with a high connected_k). Such a dataset relies on the MST stage to find the appropriate edges, as they are not included among the nearest neighbors. Worse performance at low \(k\) therefore indicates that our approach did not find the exact MST edges, while the increasing recall at higher \(k\) indicates that the additional \(k\)-NN edges were detected.
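
To quantify the first concern, one can check which fraction of ground-truth \(k\)-MST edges already occurs in the symmetrized \(k\)-NN graph. A minimal sketch, assuming `X`, `k`, and the binarized ground-truth graph `m1` from the sweep above:

from sklearn.neighbors import kneighbors_graph

# Assumes X, k, and the binarized ground-truth graph m1 from the sweep loop.
# Symmetrized k-NN connectivity graph: an edge exists when either point
# lists the other among its k nearest neighbors.
knn = kneighbors_graph(X, n_neighbors=k, mode="connectivity")
knn = knn.maximum(knn.T)

# Fraction of ground-truth k-MST edges that are already k-NN edges.
knn_overlap = knn.multiply(m1).nnz / m1.nnz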

[22]:
dataset_order = ["iris", "wine", "diabetes", "digits", "horse", "mnist"]
display_name = dict(
    iris="Iris",
    wine="Wine",
    diabetes="Diabetes",
    digits="Digits",
    horse="Horse",
    mnist="MNIST",
)
[23]:
results.groupby(["dataset", "num_neighbors"]).global_recall.mean().unstack().reindex(
    dataset_order
).round(2)
[23]:
num_neighbors 2 3 4 6
dataset
iris 0.95 0.95 0.97 0.95
wine 1.00 1.00 1.00 1.00
diabetes 1.00 1.00 0.99 1.00
digits 0.99 0.99 0.99 0.99
horse 1.00 1.00 1.00 1.00
mnist 0.99 0.97 0.97 0.97

Another way to measure the quality of our approximate \(k\)-MST is to compare the total distance over its edges to that of the ground-truth \(k\)-MST. The table below shows the approximate total distance divided by the true total distance. An optimal solution has a value of \(1\). Higher values are worse; lower values occur when non-exact approximate edges connect components that are not yet connected by ground-truth edges. Again, the most difficult dataset deviates most from \(1\) at lower \(k\), indicating we did not find the exact MST edges.
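
In symbols, writing \(\hat{E}\) for the approximate and \(E^*\) for the ground-truth \(k\)-MST edge set (notation introduced here for clarity) and \(d(e)\) for an edge's length, the reported value is

\[ \text{distance fraction} = \frac{\sum_{e \in \hat{E}} d(e)}{\sum_{e \in E^*} d(e)} \]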

[24]:
results.groupby(["dataset", "num_neighbors"]).global_dist_frac.mean().unstack().reindex(
    dataset_order
).round(2)
[24]:
num_neighbors 2 3 4 6
dataset
iris 0.19000 0.220000 0.24000 0.270000
wine 0.90000 0.990000 1.05000 1.090000
diabetes 0.04000 0.050000 0.05000 0.060000
digits 10.89000 11.800000 12.42000 13.300000
horse 0.00000 0.010000 0.01000 0.010000
mnist 639.22998 678.619995 702.76001 736.119995

How accurate is our NN-Descent in finding shortest edges between connected components?

Instead of looking at the global \(k\)-MST, let's zoom in on the Boruvka algorithm and see how well we find the edges we are looking for. Here, the Iris dataset gives the lowest recall. It is also the dataset with the highest connecting \(k\), meaning that the MST stage is actually required!

[25]:
sized_fig(1/2)
sns.pointplot(
    exploded,
    x="boruvka_iteration",
    y="boruvka_recall",
    hue='dataset',
    hue_order=dataset_order,
    ci=95,
    units="repeat",
    markers=['o', 's', 'd', 'x',  'v', 'p'],
    linewidth=0.5,
    linestyle=':',
    markersize=3,
    palette="tab10",
    native_scale=True,
    legend=False,
)
plt.ylim([0, 1.05])
plt.ylabel('Bor\\r{u}vka recall')
plt.xlabel('Bor\\r{u}vka iteration')
plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)
plt.savefig('images/boruvka_recall_vs_iterations.pdf')
plt.show()
_images/Benchmark_recall_32_0.png
[26]:
exploded.groupby(['dataset', 'boruvka_iteration']).boruvka_recall.mean().unstack().reindex(dataset_order)
[26]:
boruvka_iteration 0 1 2 3 4 5
dataset
iris 0.803427 0.66934 0.893403 NaN NaN NaN
wine 0.999453 1.0 1.0 1.0 NaN NaN
diabetes 0.999115 1.0 0.997917 NaN NaN NaN
digits 0.970977 0.968882 0.958517 0.981042 NaN NaN
horse 0.997149 0.99038 0.992793 0.998381 0.988462 1.0
mnist 0.982824 0.988905 0.995083 0.985923 0.994375 NaN

The recall in the descent stage tells a different story, indicating that finding connecting edges for points that are not part of the shortest edge between components is more difficult.

[27]:
sized_fig(1/2)
display = exploded.copy()
display['dataset'] = display['dataset'].replace(display_name)
sns.pointplot(
    display,
    x="boruvka_iteration",
    y="descent_final_recall",
    hue='dataset',
    hue_order=[display_name[l] for l in dataset_order],
    ci=95,
    units="repeat",
    markers=['o', 's', 'd', 'x',  'v', 'p'],
    linewidth=0.5,
    linestyle=':',
    markersize=3,
    palette="tab10",
    native_scale=True
)
plt.legend(title='')
plt.ylim([0, 1.05])
plt.ylabel('Descent recall')
plt.xlabel('Bor\\r{u}vka iteration')
plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)
plt.savefig('images/descent_recall_vs_iteration.pdf')
plt.show()
_images/Benchmark_recall_35_0.png

How do the parameters influence NN-Descent’s performance?

[28]:
sized_fig(1/2)
sns.pointplot(
    exploded,
    x="num_neighbors",
    y="boruvka_recall",
    hue='dataset',
    hue_order=dataset_order,
    ci=95,
    units="repeat",
    markers=['o', 's', 'd', 'x',  'v', 'p'],
    linewidth=0.5,
    linestyle=':',
    markersize=3,
    palette="tab10",
    native_scale=True,
    legend=False,
)
plt.ylim([0, 1.05])
plt.ylabel('Bor\\r{u}vka recall')
plt.xlabel('Num. neighbors ($k$)')
plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)
plt.savefig('images/boruvka_recall_vs_neighbors.pdf')
plt.show()
_images/Benchmark_recall_37_0.png
[29]:
sized_fig(1/2)
sns.pointplot(
    exploded,
    x="min_descent_neighbors",
    y="boruvka_distance_fraction",
    hue='dataset',
    hue_order=dataset_order,
    ci=95,
    units="repeat",
    markers=['o', 's', 'd', 'x',  'v', 'p'],
    linewidth=0.5,
    linestyle=':',
    markersize=3,
    palette="tab10",
    legend=False,
    native_scale=True
)
plt.ylim([0.98, 1.1])
plt.yticks([1, 1.05, 1.1])
plt.xticks(min_descent_neighbours)
plt.ylabel('Distance fraction')
plt.xlabel('Descent neighbors ($k_{descent}$)')
plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)
plt.savefig('images/boruvka_dist_fract_vs_descent_neighbors.pdf', pad_inches=0)
plt.show()
_images/Benchmark_recall_38_0.png

How did the descent stage converge?

[30]:
fig = sized_fig(1/2)
for k_d in twice_exploded.min_descent_neighbors.unique():
    for it in twice_exploded.boruvka_iteration.unique():
        if np.isnan(it):  # skip padding entries without a Boruvka iteration
            continue
        for dataset in twice_exploded.dataset.unique():
            for k in twice_exploded.num_neighbors.unique():
                for r in twice_exploded.repeat.unique():
                    d = twice_exploded.query(
                        f"dataset == '{dataset}' and num_neighbors == {k} and min_descent_neighbors == {k_d} and boruvka_iteration == {it} and repeat == {r}"
                    )
                    plt.plot(
                        d.descent_iteration,
                        d.descent_recall,
                        linewidth=0.2,
                        alpha=0.3,
                        color="k",
                    )
        # plt.vlines(x=exploded.descent_recall.apply(len).max(), ymin=0, ymax=1, color="r", linestyle="--")
        plt.ylim([0, 1])
        # plt.xlim([0, 50])
        # plt.xticks([0, 25, 50])
plt.ylabel('Descent recall')
plt.xlabel('Descent iteration')
plt.subplots_adjust(0.16, 0.24, 0.95, 0.95)
plt.savefig('images/descent_convergence.pdf', pad_inches=0)
plt.show()
_images/Benchmark_recall_40_0.png