{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Recall Benchmark\n", "\n", "This notebook measures the recall of $k$-MSTs computed with NN-Descent compared\n", "to the ground-truth $k$-MSTs. We vary several parameters to investigate how the\n", "algorithm behaves: the dataset, the number of neighbors, and the\n", "number of descent neighbors. The latter indicates how many neighbors are\n", "used in the NN-Descent stage. When it is higher than the number of neighbors\n", "required for the $k$-NN graph to be a single connected component, plain\n", "NN-Descent should already find all MST edges, so the performance of the MST-descent\n", "stage is not measured well. Throughout the parameter sweep, we therefore record the\n", "number of neighbors a dataset requires for its $k$-NN graph to be a single connected\n", "component. In addition, we measure the recall and distance fraction for the\n", "global output, each Boruvka iteration, and each descent iteration. The global\n", "distance fraction is computed over all edges. 
The Boruvka and Descent distance\n", "fractions only consider the ground-truth edges of each Boruvka iteration.\n", "\n", "The main questions we want to answer are:\n", "\n", "- How accurate is our NN-Descent for constructing $k$-MSTs?\n", "- How accurate is our NN-Descent in finding the shortest edges between connected components?\n", "- How do the parameters influence NN-Descent convergence?\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "metadata": {} }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from scipy.sparse.csgraph import connected_components\n", "from sklearn.datasets import load_diabetes, load_iris, load_digits, load_wine, fetch_openml\n", "from sklearn.preprocessing import RobustScaler\n", "\n", "from umap import UMAP\n", "from multi_mst import KMST, KMSTDescentLogRecall\n", "from lib.drawing import draw_umap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Datasets\n", "\n", "The cells in this section load and pre-process the datasets (where necessary)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "metadata": {} }, "outputs": [], "source": [ "data = {}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SKLearn Diabetes" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X, y = load_diabetes(return_X_y=True)\n", "p = UMAP(n_neighbors=5).fit(X)\n", "draw_umap(p)\n", "data['diabetes'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SKLearn Iris" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACYCAYAAAC29DagAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAS7ElEQVR4nO3de1xUdf7H8RfDIAioiAhyCQVEQblpKgpeQEPX9KHLblu7Wqs/N7e2rDarVTNb81G5lpW2a+uutPuo1czWMkkR5OINYVERRNRoAiSVvAxXYRiG4czvD5OHrm6FDh5kPs+/YA6c8zkM7/lezs3OYrFYEMLGadQuQIjOQIIgBBIEIQAJghCABEEIQIIgBCBBEAKQIAgBSBCEACQIQgASBCEACYIQgARBCECCIAQgQWi3tUkbMba0ql2GsDI7uR7hx2tpVQiOGk15UR52dnZqlyOsSFqEdlDMLXg6ayQEXZAEoR2amprQ2sufrCuSd7UddDodzs7OapchOoAEoR1ycnLw8/NTuwzRASQI7XCk4BjePhKErkirdgF3C7PZjF5/idrqKrVLER1AWoQfoaWlhWmJD6J1dOHL8jOYzWa1SxJWJscRfoQHHniAZnMrwyLC+ebMGapr60ne9qnaZQkrkq7RD3j++efp378/kydPJj4+nt3p6TzxzHMYjUacnJzULk9YiXSNvsfq1avp27cvbn36Eh8fT7du3Zg+bRpeXl58si1Z7fKEFUkQ/oeFi5fyWeoeEhISKD55im7dugFgNBoZHHAP1RcqVa5QWJME4X/IzTuMv0cPkpKSiBoagslkAmDz5s2MGzeO5uZmlSsU1iRBuAmTycSMyfFERERSVFLOoEGDKCoqAqCsrIzQ0FCVKxTWJkG4iZMnTxIUFMTAgUEMGzKQ8+fPU1VVxRtvr+HQsRM0NTVxoaaeusuNapcqrERmjW7i0qVLGAwGRowYQVxcHJPun8GI8FDKy8sZNCScnj17Mi1hIj1cuqtdqrASaRFuQlEUzpw5Q1hYGJ6enkyIiaaxsZEhQ4bw7ttvkpKSgqIoaDTy5+sq5J28ieLiYsLCwtq+HxYZgUmxY/fBIyxbtgwPDw+Sd+xUsUJhbRKE/9JgMFLx7SXq6upQFAUAB0cntI5OJH/8ISEhIRQWFpJ++BSKIgfluwo5xeK/5ObmEhoaSnJyMt7e3gD8Y+MW1r75Op6enhgMBtLS0hgTO5Z+nn1VrlZYiwyW/0tDQwNubm74+v
oyadIkDh8+zKvLX8LT05PS0lJKSkqYOXOmjA+6GHk3v0dtbS319fUEBQwgOzubmpoa7r//fglBFyTv6DXOnz9Pnz592r7ft28fY8aMYfv27QQGBjJixAgVqxMdSYJwjRMnThAREQFcOajWr18/MjIymDp1Kj4+PipXJzqSBOEa2YcK0Gg06PV6dDodiqIwY8aMthPuRNclg+VrHMjazZYBfuzK3MvkuLGMGTNG7ZLEHSItwndqa2upMSps2/45f3r1lbapU2EbJAjfyc07hHvv3mzetAmffl5qlyPuMAnCd3Z+kcwH69dib2+vdilCBRKE7xgMBpkZsmESBECv18vMkI2TIAAff/wxEydOVLsMoSIJAlcOnt13331qlyFUZPNBKC0txdHREXd3d7VLESqy+SAcOnSI4OBgtcsQKrPpICiKwunTpwkKClK7FKEymw5Cfn4+/v7+ODo6ql2KUJlNB6GyshKj0UhISIjapQiV2WwQjEYjX5eW4eXlRb9+/a5bdvVaZWE7bDYIWVlZFH9VhoODww3LTCbTTV8XXZdNBkFRFD769Aue+M2vKS0tvWG5yWSScYON6ZAgnL9wgRVvrMHQZOyI1d+2f236CF8vD4YNG0ZDQ8MNy41Go7QINsZqF+Zs/vc2jh7KISf/GN9e1NO3lwsH92YQN2E8i//wQqd5SPe2bdsoLDjK2LHjKSkpwdvb+4a71kmLYHuscl+jE1+VMv+xJ4gKDaK7iyu+Pj4MHRKKXq9nV3oWdoqZwIAB+Pv7k5iYqNpR3OTkZM6dO8eAAQPo1q0biqIwdOhQKisrr7swv7S0FEVR5ECbDbFKizB0UBD33OPLzJkz+demzTy3cCE+3ldmYmbPno2iKKSnp9Pc3MwHH3yA0Wjk0LFTuDjaExk2hJaWFmpqaqiovEifXi70dnNjwKAhzHv4l2i11mm0tm/fzp49e7hUZ+CSvoqI8DBcXFzw8fHh+PHj1/2syWTCxcXFKtsVdwerjRFe/+NScnJyiYoIbwuB2WymoaGB6upqBg4ciNFopPLCRfTVNQQPuIddael8vmMXAwcOJCQkhOUvLeHNN95g1qxZNBsu88i8+ax4beVtTWeazWYWLV7CM4uXYTAYaG1ppujEKd7883qOntQBoNForttGc3OznJZtY6x6y8cFi/9ID60F9149gCv/YA4ODjg4ONC9e3d69+6Na89eaDQazKZm6uvrmfPE75kYM4rP/r3lhk9/RVF4//33yc7O5p133ml3lyozM5O3/voPLlddZFR0NOe+PU/YoECef27hdQ8CvHjxIhUVFYwcORKAw4cPExwcjJub2+39QcTdw6Ky8xcvWR586JeWUaNGWdLS0m76M8XFxZYpMx6w7N69+wfXV1NTY/nnhxsti1/6o+Xzzz+3NBmb25a1tLRYFi1aZGltbb3h91JTU9u+PnDggKWxsfEW9kbcrTrNTYB37NjBqlWrGD58ONOnTyc2NhZnZ+e25clf7GT37lT6eXmxePHittajtraWwsJCWlpaAKioqGDgwIFMmDDhpjNVJSUl7Ny5k4ULF173enp6OvHx8Wi1WrKyshg/frzVxiei8+s0QYArd5pbvXo1zj1707ePOzg44WxvwQLkFpxkZPggjhw+RFV9I88seJKers707NmTyMhINBoNqampREVF4e/v/73b2bBhAxEREURHR7e9ptfrKS0tJTo6mszMTCZNmtTBeys6k071kTd06FBee+01/vyXdXh7efLIIw+3tQql5acpOXWSifFxfHv+AnW11fw8cSYAZ8+eJT8/n6lTp/6oQe78+fN58cUXGTZsWNvPe3h4kJ+f32H7Jjq3TneKhY+PD0tfXMKZM9+wfv36tse6BgUM4P777ycoKAgnx26UlpbyySefkJeXx+nTp5k5c2a7ZnoWLFjA22+/fd1rGo0Gs9ls1f0Rd4dO1TW6ltlsZuXKK1OnS5cuvaG/Xl9fz6y581m08GnGjY29pW1s2r
QJLy+vtuuVq6urKSkpwWAwSNfIxnS6FuEqrVbLsmXLcHNzY8mSJTccS+jZsydPzp/LnqzMW97G7Nmz2bdvX9v5Ru7u7tTX199W3eLu1GmDcNUzzzxDWFgYCxYsuCEMqVn7cXV15Ztvvrmt9V/bRdJqtdI9skGdPggAc+bMISEhgXnz5mGxWDCbzaxZswZLi5GsrCx++/jvuHTp0i2t28PDg8jISLZt2wbAvffey5dffmnN8sVdoFPNGn2fxMREnJxd+e1TC6mrqSI8dDCjRtxLTZWeAt0Ztn76GQP6++Pl5UVUVFS7Hu80c+ZMXn31VWJjY/H09KSpqakD90R0RndFi3DV1CkJvPLSEnr16kXqngN8+umnWCwWhgT68YsHfk5LSwsWi4W0tDTS09Nveq3B/7JgwQLWrVuHoijY29u3zVYJ29BpZ41+rMrKSlJ2Z+Lj2YdRo0ah0+kwGAzExsZy6NAhGhsbCQwMZPDgwT+4royMDC5cuICrqyt9+vRh7Nixd2APRGdw1wfhKkVRyM3Npa6ujrCwMAoLCwkPDycgIACdTkdZWRmOjo7ExMR87/GG1atX4+npiZeXF1OmTLmDeyDU1GWCcJWiKGRnZ9PY2Iirq2vbMQGtVktDQwO5ubmYzWbCw8Px8/O74fcNBgNzn3qBx2b9jHHjxsnp2DaiywXhKrPZzP79+7l8+TJms5mwsLC27pGiKBQWFnLhwgXc3d0ZOXLkdYPrOfPms/btNykuLpbukY3oskG4ymQysW/fPsrLy+nbty/Tpk277lNer9e3nWMUHR2Nm5sbb731FqGhodjb20v3yEZ0+SBcZTQaycjI4Pjx40yZMoXhw4dft/zqGKO+vp6cQ0fZvS+blS+9QExMzHUX8YiuyWaCcFVDQwObN29Gr9fz1FNP4erqet3ynJwcVq9eTWjoEPbt38fvnljA7F89pFK14k6xuSBcpdfr+ct763Hu0Yt7I4YAoNPp+PuHm3n04YcIDg7m9JlzbPjwYw7tTVW5WtHR7qoDatbk4eHB8pdfYvaDP6OpqQk7jT3p6enkZKUxbdo0FEXB39ebXg6tZO7Zq3a5ooPdNadYdBRfX198fX2ZNfdRHvn1HLRaLf3796d///4AXG40suy11UyKj1O1TtGxbLZFuFZDQwNRQwYxOeE+UlJS2u6HWllZybffnsPJSY4ldHU2O0a4VlJSEvHx8W1Pzjl69Ch7DxzkcN5/mPqTKQQMCmHc6FEqVyk6ks13jQDOnz9PQEAAe/bsYdeuXVy8eJEebu5s2LABVxfnH16BuOvZfIuw4vVV7M7aR1RIIJ6ensyYMYOoqCi1yxJ3mE23CI8//Syfbd/J39e+yahRI/Hx8VG7JKESmw2CxWKh3tjK6VOF191ITNgmm+8aCQEyfSoEIEEQApAgCAFIEIQAJAhCABIEIQAJghCABEEIQIIgBCBBEAKQIAgBSBCEACQIQgASBCEACYIQgARBCECCIAQgQRACkCAIAUgQhAAkCEIAEgQhAAmCEIAEQQhAgiBuwZdfn+bhx55RuwyrkiCIdvtg00ds/Xgj7657T+1SrEZu+SjaraGhgVWr3uAv73+IrugIHh4eapd022z2JsDi1rm6uhITMwbFzp4JcXEczc/H0dFR7bJui3SNxC2LGxfDQ7+azcNzf6N2KbdNgiBumdls5re/+T+8+rixZs0atcu5LRIE0W61tbX06NEDNzc3qqqqePTRRykrK2Pr1q1ql3bLJAii3fLy8hg9ejTR0dEUFBSg1WpZvnw5eXl57NixA0VR1C6x3SQIol0URUFRFLRaLRqNBjs7O8xmM+7u7kyfPp28/EKSk7/AbG7t0BoKi4rZn5NntXXKrJFol6NHjxIZGdn2/YgRI8jNzaVHjx4cP36cs2fPcvHCebIOHOQPzz6Nn5+f1bb9yec76e1y5ZnX/v378/tFS5n9yweZP2f2ba9bWgTRLqe+Lm976GJtbS3//uxzUjOyOHz4MD
ExMfxzw3qm/WQyrYqFFxYtZuvWrRiNxtvbZomOx558mnLdKRISEhgzZgzFx4/TTWtPNyv9B0uLINolOzePlW+t5acJ4zl37hwBAYEorWa0Wi3Dhw8HIDw8nK+++oqRkUM5fvw4p06dYty4ccTFxbVrW42GJp5ftIStO3YzOmoIAwfcQ2ZmJk5OThQUFBAREsRDv/i5VfZLjiyLdluw8A8UF59Ea9fK4OAgqi830dvFkffeu3LKRXJyMqNHj6aoqIiKigp69eqFyWSitLSUWbNmERQU9L3rr66uJi8vj12pqbi5e+Bgr+Hc2TM4OjqycuVKtmzZgoODAyaTiXnz5lllnyQI4pYpikJ6ejrHjh1j86fbCQsZxNMLnsDb2xs/Pz+Kioqoq6ujurqalpYWJk6cyMaNG2lububJJ5+84bG+Op2Or7/+Gjc3N9zd3fniiy+or69nxYoV5OXlsfyVVwgIDGRAQBAevXsRGxvL4MGDrbIvEgRhFQaDgUWLF/OfgmI2Jv217R80MzMTrVZL79690el0TJo0CZPJxPr16/Hz82P2w4+Qk5uL0dBIYGAggwcPxmg0kpSURGVlJY8//jj+/v4ArH5nLX/b8D6REeEMCgrg9ddetVr9MlgWVuHs7Myf332XfWk7+NOqVTz13CJMJhPx8fFUV1dz9uxZJk2aRHZ2NiaTiZdffpmIiAiipySSk/sfpk6d2haetLQ0NBoN3t7ebSEwGo0MDRnEKy8tRqfTUdvUYtX6pUUQHeLAwRz+9tf38PLyIj4+HkVRaGlpITExkb179+Lr64tOp0Or1ZKamkpoaCiPPfYYubm5NDc3s2XLFtatW4dGc+WzOiUlhcmTJ/Piiy9y5GgBaampOGjtrVavBEF0qCNHjpCSkkJdXR1VVVXExcUxd+5c8vLyWLbidTb+cwOenp48++yz9O3bl/Hjx5OSkkJ8fDwJCQkAlJeXU1NTw8GDB9m0ZSurVq1iQuxoq9YpQRAdzmg0kpGRgV6vJz8/Hzs7O1xcXKiqqeNYQT4xMTF4eHiwLSWDyJBA7OzsSEpKAq4MyDdt2kRZWRl5eXnMnz+fxMREq9coQRB3TElJCf/a+BEGk4kzFRVoLK0YDAbKTlfgf48/FsXMiRMncHBy5qEHH6SHS3cy9ufQ1FCPd79+zJv7a6ZPn94htUkQxB1laDJytLCQ+tpa6mpr0Ov1FBWfYHfWARobL6OgIWJwIJcvX8auW3dCggMJHhhMo6mVlcuXYm9vvXHBtSQIotOoOFuJp0cfujtdudqtpVXBwf7OTGxKEIRAjiMIAUgQhAAkCEIAEgQhAAmCEIAEQQhAgiAEIEEQApAgCAFIEIQAJAhCABIEIQAJghCABEEIAP4f3XbgyRRprXIAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X, y = load_iris(return_X_y=True)\n", "p = UMAP(n_neighbors=5).fit(X)\n", "draw_umap(p)\n", "data['iris'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SKLearn Digits" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X, y = load_digits(return_X_y=True)\n", "p = UMAP(n_neighbors=5).fit(X)\n", "draw_umap(p)\n", "data['digits'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SKLearn Wine" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X, y = load_wine(return_X_y=True)\n", "X = RobustScaler().fit_transform(X)\n", "p = UMAP(n_neighbors=5).fit(X)\n", "draw_umap(p)\n", "data['wine'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Horse" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X = pd.read_csv('data/horse/horse.csv').to_numpy()\n", "p = UMAP(n_neighbors=20).fit(X)\n", "draw_umap(p)\n", "data['horse'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MNIST" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\jelme\\Documents\\Development\\work\\multi_mst\\multi_mst\\notebooks\\lib\\drawing.py:8: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored\n", " s = plt.scatter(xs, ys, c=color, s=size, edgecolors=\"none\", linewidth=0, cmap='viridis', alpha=alpha)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X, target = fetch_openml(\"mnist_784\", version=1, return_X_y=True)\n", "p = UMAP().fit(X)\n", "draw_umap(p)\n", "data['mnist'] = X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parameter Sweep\n", "\n", "The cell below evaluates all parameters multiple times for each dataset and collects the measurements in a single dataframe." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "metadata": {} }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 6/6 [13:47:22<00:00, 8273.83s/it] \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " dataset boruvka_num_components boruvka_recall \\\n", "0 diabetes [92, 17, 3] [1.0, 1.0, 1.0] \n", "1 diabetes [93, 17, 3] [1.0, 1.0, 1.0] \n", "2 diabetes [93, 17, 3] [1.0, 1.0, 1.0] \n", "3 diabetes [93, 17, 3] [1.0, 1.0, 1.0] \n", "4 diabetes [93, 17, 3] [1.0, 1.0, 1.0] \n", "\n", " descent_recall \\\n", "0 [[0.91847825, 0.9470109, 0.95516306, 0.9578804... \n", "1 [[0.95694447, 0.9583333, 0.9583333, 0.9583333,... \n", "2 [[0.96056336, 0.96056336, nan, nan, nan, nan, ... \n", "3 [[0.96153843, 0.96153843, nan, nan, nan, nan, ... \n", "4 [[0.9046961, 0.9378453, 0.95027626, 0.9530387,... \n", "\n", " descent_distance_fraction \\\n", "0 [[1.0, 1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan,... \n", "1 [[1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan, nan,... \n", "2 [[1.0, 1.0, nan, nan, nan, nan, nan, nan, nan,... \n", "3 [[1.0, 1.0, nan, nan, nan, nan, nan, nan, nan,... \n", "4 [[1.0006415, 1.0006415, 1.0, 1.0, 1.0, 1.0, 1.... \n", "\n", " descent_num_changes \\\n", "0 [[3020.0, 323.0, 36.0, 11.0, 2.0, nan, nan, na... \n", "1 [[4950.0, 88.0, 16.0, 2.0, nan, nan, nan, nan,... \n", "2 [[6806.0, 7.0, nan, nan, nan, nan, nan, nan, n... \n", "3 [[7889.0, 0.0, nan, nan, nan, nan, nan, nan, n... \n", "4 [[2998.0, 345.0, 40.0, 10.0, 5.0, 4.0, 0.0, na... 
\n", "\n", " boruvka_distance_fraction global_recall global_precision \\\n", "0 [1.0, 1.0, 1.0] 0.99422 0.996139 \n", "1 [1.0, 1.0, 1.0] 1.00000 1.000000 \n", "2 [1.0, 1.0, 1.0] 1.00000 1.000000 \n", "3 [1.0, 1.0, 1.0] 1.00000 1.000000 \n", "4 [1.0, 1.0, 1.0] 0.99422 0.994220 \n", "\n", " global_dist_frac connected_k num_observations num_dimensions \\\n", "0 0.996875 3 442 10 \n", "1 1.000000 3 442 10 \n", "2 1.000000 3 442 10 \n", "3 1.000000 3 442 10 \n", "4 1.000280 3 442 10 \n", "\n", " min_descent_neighbors num_neighbors repeat \n", "0 6 2 0 \n", "1 12 2 0 \n", "2 24 2 0 \n", "3 36 2 0 \n", "4 6 2 1 " ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repeats = 5\n", "neighbours = [2, 3, 4, 6]\n", "min_descent_neighbours = [6, 12, 24, 36]\n", "results = []\n", "for dataset_name in tqdm(data.keys()):\n", " X = data[dataset_name]\n", "\n", " connected_k = 2\n", " while True:\n", " p = UMAP(n_neighbors=connected_k, transform_mode=\"graph\").fit(X)\n", " if connected_components(p.graph_, directed=False, return_labels=False) == 1:\n", " break\n", " connected_k += 1\n", " data_size = X.shape[0]\n", " data_dims = X.shape[1]\n", "\n", " for k in neighbours:\n", " # Compute ground truth\n", " kmst = KMST(num_neighbors=k).fit(X)\n", " m1 = kmst.umap(transform_mode=\"graph\").graph_.copy()\n", " m1.data[:] = 1\n", "\n", " for repeat in range(repeats):\n", " for n in min_descent_neighbours:\n", " if n is None:\n", " n = k\n", " elif n < k:\n", " continue\n", " # Compute Descent kMST\n", " dmst = KMSTDescentLogRecall(\n", " num_neighbors=k,\n", " min_descent_neighbors=n,\n", " ).fit(X)\n", " m2 = dmst.umap(transform_mode=\"graph\").graph_.copy()\n", " m2.data[:] = 1\n", "\n", " # Extract trace measures\n", " true_positive = m1.multiply(m2).nnz\n", " if len(dmst.trace_) == 0:\n", " measures = pd.DataFrame(\n", " {\n", " \"dataset\": [dataset_name],\n", " \"boruvka_num_components\": [[]],\n", " \"descent_distance_fraction\": 
[[]],\n", " \"boruvka_recall\": [[]],\n", " \"boruvka_distance_fraction\": [[]],\n", " \"descent_num_changes\": [[]],\n", " \"descent_recall\": [[]],\n", " }\n", " )\n", " else:\n", " measures = pd.DataFrame(dmst.trace_)\n", " measures[\"boruvka_distance_fraction\"] = measures[\n", " \"descent_distance_fraction\"\n", " ].apply(lambda x: x[np.argmax(np.isnan(x)) - 1])\n", " # Convert to one row with lists\n", " measures[\"dataset\"] = dataset_name\n", " measures = measures.groupby(\"dataset\").agg(list).reset_index()\n", " # Add per-run measures\n", " measures[\"global_recall\"] = true_positive / m1.nnz\n", " measures[\"global_precision\"] = true_positive / m2.nnz\n", " measures[\"global_dist_frac\"] = (\n", " dmst.graph_.data.sum() / kmst.graph_.data.sum()\n", " )\n", " measures[\"connected_k\"] = connected_k\n", " measures[\"num_observations\"] = data_size\n", " measures[\"num_dimensions\"] = data_dims\n", " measures[\"min_descent_neighbors\"] = n\n", " measures[\"num_neighbors\"] = k\n", " measures[\"repeat\"] = repeat\n", " results.append(measures)\n", "\n", "results = pd.concat(results, ignore_index=True)\n", "results.head()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "metadata": {} }, "outputs": [], "source": [ "results.to_parquet(\"./data/generated/recall_benchmark.parquet\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results\n", "\n", "This section creates plots showing our results for each of our questions:\n", "\n", "- How accurate is our NN-Descent for constructing $k$-MSTs?\n", "- How accurate is our NN-Descent in finding shortest edges between connected components?\n", "- How do the parameters influence NN-Descent's performance?" 
] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from lib.plotting import *\n", "configure_matplotlib()\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", category=FutureWarning)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "results = pd.read_parquet(\"./data/generated/recall_benchmark.parquet\")\n", "min_descent_neighbours = results.min_descent_neighbors.unique()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " dataset num_observations num_dimensions connected_k\n", "0 diabetes 442 10 3\n", "1 digits 1797 64 8\n", "2 horse 8431 3 15\n", "3 iris 150 4 26\n", "4 mnist 70000 784 4\n", "5 wine 178 13 4" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results.groupby('dataset').agg({'num_observations': 'first', 'num_dimensions': 'first', 'connected_k': 'first'}).reset_index()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "results[\"num_boruvka_iters\"] = results.boruvka_recall.apply(len)\n", "results[\"boruvka_iteration\"] = results.num_boruvka_iters.apply(lambda x: list(range(x)))\n", "\n", "vars = [\n", " \"dataset\",\n", " \"repeat\",\n", " \"connected_k\",\n", " \"num_neighbors\",\n", " \"num_observations\",\n", " \"min_descent_neighbors\",\n", "]\n", "measures = [\n", " \"boruvka_iteration\",\n", " \"boruvka_num_components\",\n", " \"boruvka_recall\",\n", " \"boruvka_distance_fraction\",\n", " \"descent_num_changes\",\n", " \"descent_recall\",\n", " \"descent_distance_fraction\",\n", "]\n", "exploded = results[vars + measures].explode(measures).reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "exploded['num_descent_iters'] = exploded.descent_recall.apply(lambda x: len(x) if hasattr(x, '__len__') else 0)\n", "exploded['descent_converged'] = exploded.num_descent_iters != exploded.descent_recall.apply(lambda x: np.argmax(np.isnan(x)))\n", "exploded['descent_iteration'] = exploded.num_descent_iters.apply(lambda x: list(range(x)))\n", "\n", "vars = ['dataset', 'repeat', 'connected_k', 'num_neighbors', 'min_descent_neighbors', 'boruvka_num_components', 'boruvka_iteration', 'descent_converged']\n", "measures = ['descent_iteration', 'descent_num_changes', 'descent_recall', 'descent_distance_fraction']\n", "twice_exploded = exploded[vars + measures].explode(measures).reset_index(drop=True)" ] }, { 
"cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "exploded['descent_final_recall'] = exploded.descent_recall.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)\n", "exploded['descent_final_dist_frac'] = exploded.descent_distance_fraction.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)\n", "exploded['descent_final_changes'] = exploded.descent_num_changes.apply(lambda x: x[np.argmax(np.isnan(x)) - 1] if hasattr(x, '__len__') else np.nan)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How accurate is our NN-Descent for constructing $k$-MSTs?\n", "\n", "The globall recall was high (>0.94) in all cases, which is a promising result.\n", "There are two concerns:\n", "\n", "1. Most edges in the $k$-MST are from the $k$-NN, so global recall mostly\n", " reflects how well $k$-NN are found, rather than the MST edges.\n", "2. The global recall was worse at low $k$ for a difficult dataset (with high\n", " connected_k). This is a dataset that requires the MST stage to find the\n", " appropriate edges, as they are not included in the nearest neighbors. So, a\n", " worse performance at low $k$ indicates our approach did not find the nearest\n", " neighbors. The increasing recall at higher $k$ indicates that the higher\n", " neighbours were detected." ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "dataset_order = [\"iris\", \"wine\", \"diabetes\", \"digits\", \"horse\", \"mnist\"]\n", "display_name = dict(\n", " iris=\"Iris\",\n", " wine=\"Wine\",\n", " diabetes=\"Diabetes\",\n", " digits=\"Digits\",\n", " horse=\"Horse\",\n", " mnist=\"MNIST\",\n", ")" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "num_neighbors 2 3 4 6\n", "dataset \n", "iris 0.95 0.95 0.97 0.95\n", "wine 1.00 1.00 1.00 1.00\n", "diabetes 1.00 1.00 0.99 1.00\n", "digits 0.99 0.99 0.99 0.99\n", "horse 1.00 1.00 1.00 1.00\n", "mnist 0.99 0.97 0.97 0.97" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results.groupby([\"dataset\", \"num_neighbors\"]).global_recall.mean().unstack().reindex(\n", " dataset_order\n", ").round(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to measure the quality of our approximate $k$-MST is by comparing\n", "the total distance over its edges to the ground-truth $k$-MST. The figure below\n", "show the approximate total distance divided by the true total distance. An\n", "optimal solution has a value of $1$. Higher values are worse, lower values\n", "happen when non-exact approximate edges connect components not yet connected by\n", "ground truth edges. Again the most difficult dataset is most different from $1$\n", "lower $k$, indicating we did not find the exact MST edges." ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "num_neighbors 2 3 4 6\n", "dataset \n", "iris 0.99 0.99 1.0 1.0\n", "wine 1.00 1.00 1.0 1.0\n", "diabetes 1.00 1.00 1.0 1.0\n", "digits 1.00 1.00 1.0 1.0\n", "horse 1.00 1.00 1.0 1.0\n", "mnist 1.00 1.00 1.0 1.0" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results.groupby([\"dataset\", \"num_neighbors\"]).global_dist_frac.mean().unstack().reindex(\n", " dataset_order\n", ").round(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How accurate is our NN-Descent?\n", "\n", "Instead of looking at the global $k$-MST, lets zoom in to the Boruvka algorithm\n", "and see how well we found the edges we are looking for. Here we see that \n", "the Iris dataset gave the lowest recall. This is also the dataset with the \n", "highest connecting $k$, meaning that the MST stage is actually required!" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sized_fig(1/2)\n", "sns.pointplot(\n", " exploded,\n", " x=\"boruvka_iteration\",\n", " y=\"boruvka_recall\",\n", " hue='dataset',\n", " hue_order=dataset_order,\n", " ci=95,\n", " units=\"repeat\",\n", " markers=['o', 's', 'd', 'x', 'v', 'p'],\n", " linewidth=0.5,\n", " linestyle=':',\n", " markersize=3,\n", " palette=\"tab10\",\n", " native_scale=True,\n", " legend=False,\n", ")\n", "plt.ylim([0, 1.05])\n", "plt.ylabel('Bor\\\\r{u}vka recall')\n", "plt.xlabel('Bor\\\\r{u}vka iteration')\n", "plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)\n", "plt.savefig('images/boruvka_recall_vs_iterations.pdf')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "boruvka_iteration 0 1 2 3 4 5\n", "dataset \n", "iris 0.805089 0.669039 0.894097 NaN NaN NaN\n", "wine 0.999193 1.0 1.0 1.0 NaN NaN\n", "diabetes 0.999273 1.0 1.0 NaN NaN NaN\n", "digits 0.971563 0.967937 0.95913 0.986042 NaN NaN\n", "horse 0.997153 0.990496 0.992283 0.99709 0.992228 1.0\n", "mnist 0.982846 0.989249 0.995778 0.987312 0.993333 NaN" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exploded.groupby(['dataset', 'boruvka_iteration']).boruvka_recall.mean().unstack().reindex(dataset_order)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The recall in the descent stage tells a different story, indicating that finding\n", "connecting edges for points that are not part of the shortest edge between\n", "components is more difficult." ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sized_fig(1/2)\n", "display = exploded.copy()\n", "display['dataset'].replace(display_name, inplace=True)\n", "sns.pointplot(\n", " display,\n", " x=\"boruvka_iteration\",\n", " y=\"descent_final_recall\",\n", " hue='dataset',\n", " hue_order=[display_name[l] for l in dataset_order],\n", " ci=95,\n", " units=\"repeat\",\n", " markers=['o', 's', 'd', 'x', 'v', 'p'],\n", " linewidth=0.5,\n", " linestyle=':',\n", " markersize=3,\n", " palette=\"tab10\",\n", " native_scale=True\n", ")\n", "plt.legend(title='')\n", "plt.ylim([0, 1.05])\n", "plt.ylabel('Descent recall')\n", "plt.xlabel('Bor\\\\r{u}vka iteration')\n", "plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)\n", "plt.savefig('images/descent_recall_vs_iteration.pdf')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do the parameters influence NN-Descent's performance?" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sized_fig(1/2)\n", "sns.pointplot(\n", " exploded,\n", " x=\"num_neighbors\",\n", " y=\"boruvka_recall\",\n", " hue='dataset',\n", " hue_order=dataset_order,\n", " ci=95,\n", " units=\"repeat\",\n", " markers=['o', 's', 'd', 'x', 'v', 'p'],\n", " linewidth=0.5,\n", " linestyle=':',\n", " markersize=3,\n", " palette=\"tab10\",\n", " native_scale=True,\n", " legend=False,\n", ")\n", "plt.ylim([0, 1.05])\n", "plt.ylabel('Bor\\\\r{u}vka recall')\n", "plt.xlabel('Num. neighbors ($k$)')\n", "plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)\n", "plt.savefig('images/boruvka_recall_vs_neighbors.pdf')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sized_fig(1/2)\n", "sns.pointplot(\n", " exploded,\n", " x=\"min_descent_neighbors\",\n", " y=\"boruvka_distance_fraction\",\n", " hue='dataset',\n", " hue_order=dataset_order,\n", " ci=95,\n", " units=\"repeat\",\n", " markers=['o', 's', 'd', 'x', 'v', 'p'],\n", " linewidth=0.5,\n", " linestyle=':',\n", " markersize=3,\n", " palette=\"tab10\",\n", " legend=False,\n", " native_scale=True\n", ")\n", "plt.ylim([0.98, 1.1])\n", "plt.yticks([1, 1.05, 1.1])\n", "plt.xticks(min_descent_neighbours)\n", "plt.ylabel('Distance fraction')\n", "plt.xlabel('Descent neighbors ($k_{descent}$)')\n", "plt.subplots_adjust(0.2, 0.24, 0.95, 0.95)\n", "plt.savefig('images/boruvka_dist_fract_vs_descent_neighbors.pdf', pad_inches=0)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How did the descent stage converge?" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = sized_fig(1/2)\n", "for n, k_d in enumerate(twice_exploded.min_descent_neighbors.unique()):\n", " for i, it in enumerate(twice_exploded.boruvka_iteration.unique()):\n", " if np.isnan(i):\n", " continue\n", " for dataset in twice_exploded.dataset.unique():\n", " for k in twice_exploded.num_neighbors.unique():\n", " for r in twice_exploded.repeat.unique():\n", " d = twice_exploded.query(\n", " f\"dataset == '{dataset}' and num_neighbors == {k} and min_descent_neighbors == {k_d} and boruvka_iteration == {it} and repeat == {r}\"\n", " )\n", " plt.plot(\n", " d.descent_iteration,\n", " d.descent_recall,\n", " linewidth=0.2,\n", " alpha=0.3,\n", " color=\"k\",\n", " )\n", " # plt.vlines(x=exploded.descent_recall.apply(len).max(), ymin=0, ymax=1, color=\"r\", linestyle=\"--\")\n", " plt.ylim([0, 1])\n", " # plt.xlim([0, 50])\n", " # plt.xticks([0, 25, 50])\n", "plt.ylabel('Descent recall')\n", "plt.xlabel('Descent iteration') \n", "plt.subplots_adjust(0.16, 0.24, 0.95, 0.95)\n", "plt.savefig('images/descent_convergence.pdf', pad_inches=0)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "work", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }