{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Runtime Benchmark (MNIST)\n", "\n", "This notebook compares the compute costs of UMAP (=$k$-NN), the exact $k$-MST (=KDTree-based Borůvka), and the approximate $k$-MST (=NNDescent-based Borůvka) algorithms. The dataset samples and generated graphs are stored for re-analysis and visualization. On MNIST, the approximate $k$-MST is roughly two orders of magnitude faster than the exact $k$-MST algorithm!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from tqdm import tqdm\n", "from scipy.sparse import save_npz\n", "\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.utils.random import sample_without_replacement\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "\n", "from umap import UMAP\n", "from multi_mst import KMST, KMSTDescent" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "jupyter": { "source_hidden": true } }, "outputs": [], "source": [ "# Trigger numba compilation\n", "_ = KMSTDescent().fit(np.random.rand(100, 2))\n", "_ = KMST().fit(np.random.rand(100, 2))\n", "_ = UMAP(force_approximation_algorithm=True, transform_mode=\"graph\").fit_transform(\n", " np.random.rand(100, 2)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Timed algorithms\n", "\n", "Implement parameter sweep, output logging, and timing code." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "jupyter": { "source_hidden": true } }, "outputs": [], "source": [ "def time_task(task, *args, **kwargs):\n", " \"\"\"Outputs compute time in seconds.\"\"\"\n", " start_time = time.perf_counter()\n", " result = task(*args, **kwargs)\n", " end_time = time.perf_counter()\n", " return end_time - start_time, result" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "jupyter": { "source_hidden": true } }, "outputs": [], "source": [ "def run_dmst(data, n_neighbors):\n", " mst = KMSTDescent(num_neighbors=n_neighbors)\n", " compute_time, umap = time_task(lambda: mst.fit(data).umap(transform_mode=\"graph\"))\n", " return compute_time, umap.graph_\n", "\n", "\n", "def run_kmst(data, n_neighbors):\n", " mst = KMST(num_neighbors=n_neighbors)\n", " compute_time, umap = time_task(lambda: mst.fit(data).umap(transform_mode=\"graph\"))\n", " return compute_time, umap.graph_\n", "\n", "\n", "def run_umap(data, n_neighbors):\n", " umap = UMAP(n_neighbors=n_neighbors, transform_mode=\"graph\")\n", " compute_time, umap = time_task(lambda: umap.fit(data))\n", " return compute_time, umap.graph_\n", "\n", "\n", "mains = {\"dmst\": run_dmst, \"kmst\": run_kmst, \"umap\": run_umap}\n", "\n", "\n", "def run(data, algorithm, n_neighbors):\n", " return mains[algorithm](data, n_neighbors)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "jupyter": { "source_hidden": true } }, "outputs": [], "source": [ "def compute_and_evaluate_setting(\n", " data,\n", " algorithm=\"dmst\",\n", " repeat=0,\n", " frac=1.0,\n", " n_neighbors=5,\n", "):\n", " compute_time, graph = run(data, algorithm, n_neighbors)\n", " save_npz(\n", " f\"./data/generated/mnist/graph_{algorithm}_{n_neighbors}_{frac}_{repeat}.npz\",\n", " graph.tocoo(),\n", " )\n", " return (\n", " algorithm,\n", " frac,\n", " n_neighbors,\n", " repeat,\n", " graph.nnz,\n", " compute_time,\n", " )" ] }, { "cell_type": "code", "execution_count": 6, 
"metadata": { "jupyter": { "source_hidden": true } }, "outputs": [], "source": [ "def init_file(path):\n", " handle = open(path, \"w\", buffering=1)\n", " handle.write(\n", " \"algorithm,sample_fraction,n_neighbors,repeat,num_edges,compute_time\\n\"\n", " )\n", " return handle\n", "\n", "\n", "def write_line(handle, *args):\n", " handle.write(\",\".join([str(v) for v in args]) + \"\\n\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "repeats = 5\n", "algorithms = ['dmst', 'umap', 'kmst']\n", "fraction = np.exp(np.linspace(np.log(0.1), np.log(1), 5)).round(2)\n", "n_neighbors = [2, 3, 6]\n", "\n", "total = len(algorithms) * len(fraction) * len(n_neighbors)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 7000., 12600., 22400., 39200., 70000.])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fraction * 70000" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70000, 784)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df, target = fetch_openml(\"mnist_784\", version=1, return_X_y=True)\n", "df.shape" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [5:04:38<00:00, 3655.75s/it] \n" ] } ], "source": [ "output = init_file(\"./data/generated/mnist/metrics.csv\")\n", "for repeat in tqdm(range(repeats)):\n", " pbar = tqdm(desc=\"Compute\", total=total, leave=False)\n", " for frac in fraction:\n", " sample_idx = sample_without_replacement(df.shape[0], int(df.shape[0] * frac))\n", " np.save(f\"./data/generated/mnist/sampled_indices_{frac}_{repeat}.npy\", sample_idx)\n", " X = df.iloc[sample_idx, :]\n", " for algorithm in algorithms:\n", " for k in n_neighbors:\n", " result = 
compute_and_evaluate_setting(X, algorithm, repeat, frac, k)\n", " write_line(output, *result)\n", " pbar.update()\n", " pbar.close()\n", "output.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results\n", "\n", "NNDescent-based $k$-MST is more expensive than NNDescent-based $k$-NN. Scaling appears a bit steeper but still usable. Definitely a lot quicker than KDTree-based MSTs!" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib.lines as ml\n", "from lib.plotting import *\n", "\n", "configure_matplotlib()\n", "\n", "import warnings\n", "warnings.simplefilter(action=\"ignore\", category=FutureWarning)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "values = pd.read_csv(\"./data/generated/mnist/metrics.csv\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib.ticker import FixedLocator\n", "\n", "# Fit robust linear regression in log-log space\n", "fig = sized_fig(1/2)\n", "ax = plt.gca()\n", "for i, alg in enumerate([\"kmst\", \"dmst\", \"umap\"]):\n", " alg_values = values[values[\"algorithm\"] == alg]\n", " sns.regplot(\n", " x=np.log10(alg_values[\"sample_fraction\"] * df.shape[0]),\n", " y=np.log10(alg_values[\"compute_time\"]),\n", " ci=95,\n", " order=1,\n", " robust=True,\n", " color=f\"C{i}\",\n", " units=alg_values[\"repeat\"],\n", " scatter_kws={\"edgecolor\": \"none\", \"linewidths\": 0, \"s\": 2},\n", " line_kws={\"linewidth\": 1},\n", " ax=ax,\n", " )\n", "ax.set_xlabel(\"Num. MNIST points\")\n", "ax.set_ylabel(\"Run time (s)\")\n", "\n", "# Draw log y-ticks\n", "y_ticks = np.array([0.0, 1.0, 2.0, 3.0])\n", "plt.ylim(-1, plt.ylim()[1])\n", "ax.set_yticks(y_ticks)\n", "ax.get_yaxis().set_major_formatter(lambda x, pos: f\"$10^{{{int(x)}}}$\")\n", "ax.get_yaxis().set_minor_locator(\n", " FixedLocator(locs=np.concat((\n", " np.log10(np.arange(0.2, 1, 0.1) * 10.0 ** y_ticks[0]),\n", " np.log10(np.arange(2, 10) * 10.0 ** y_ticks[None].T).ravel())\n", " ))\n", ")\n", "\n", "# Draw log x-ticks\n", "x_ticks = np.array([4.0])\n", "ax.set_xticks(x_ticks)\n", "ax.get_xaxis().set_major_formatter(lambda x, pos: f\"$10^{{{int(x)}}}$\")\n", "ax.get_xaxis().set_minor_locator(\n", " FixedLocator(locs=np.log10(np.array(\n", " [0.6, 0.7, 0.8, 0.9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 30, 40, 50, 60, 70, 80]\n", " ) * 10.0 ** x_ticks[None].T).ravel())\n", ")\n", "plt.xlim(np.log10([6000, 80000]))\n", "\n", "# Legend\n", "adjust_legend_subtitles(\n", " plt.legend(\n", " loc=\"upper left\",\n", " handles=[\n", " ml.Line2D([], [], color=f\"C{j}\", label=f\"{v}\")\n", " for j, v in enumerate(['$k$-MST (kd-tree)', '$k$-MST (descent)', '$k$-NN'])\n", " ]\n", " )\n", ")\n", "plt.subplots_adjust(left=0.17, right=0.9, top=0.95, bottom=0.24)\n", 
"plt.savefig(\"./images/mnist_scaling.pdf\", pad_inches=0)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "work", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 4 }