{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!-- # Copyright (c) 2024 Graphcore Ltd. All rights reserved. -->\n",
    "\n",
    "# GFloat Basics\n",
    "\n",
    "This notebook uses `decode_float` and `FormatInfo` to explore the properties of several floating-point formats.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "from pandas import DataFrame\n",
    "import numpy as np\n",
    "\n",
    "from gfloat import decode_float\n",
    "from gfloat.formats import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## List all the values in a format\n",
    "\n",
    "The first example shows how to list all values in a given format.\n",
    "We will choose the [OCP](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1) E5M2 format.\n",
    "\n",
    "The object `format_info_ocp_e5m2` is defined in `gfloat.formats` and describes the characteristics of that format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FormatInfo(name='ocp_e5m2', k=8, precision=3, emax=15, has_nz=True, has_infs=True, num_high_nans=3, has_subnormals=True, is_signed=True, is_twos_complement=False)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "format_info_ocp_e5m2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use this format to decode every 8-bit code from 0 to 255, and gather the results in a pandas DataFrame.\n",
    "Note that `decode_float` returns much more than just the value: it also splits out the exponent, significand, and sign, and returns the `FloatClass`, which distinguishes normal and subnormal numbers, as well as zero, infinity, and NaN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fval</th>\n",
       "      <th>exp</th>\n",
       "      <th>expval</th>\n",
       "      <th>significand</th>\n",
       "      <th>fsignificand</th>\n",
       "      <th>signbit</th>\n",
       "      <th>fclass</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>code</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0</td>\n",
       "      <td>-14</td>\n",
       "      <td>0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>FloatClass.ZERO</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.525879e-05</td>\n",
       "      <td>0</td>\n",
       "      <td>-14</td>\n",
       "      <td>1</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0</td>\n",
       "      <td>FloatClass.SUBNORMAL</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.051758e-05</td>\n",
       "      <td>0</td>\n",
       "      <td>-14</td>\n",
       "      <td>2</td>\n",
       "      <td>0.50</td>\n",
       "      <td>0</td>\n",
       "      <td>FloatClass.SUBNORMAL</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.577637e-05</td>\n",
       "      <td>0</td>\n",
       "      <td>-14</td>\n",
       "      <td>3</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0</td>\n",
       "      <td>FloatClass.SUBNORMAL</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6.103516e-05</td>\n",
       "      <td>1</td>\n",
       "      <td>-14</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0</td>\n",
       "      <td>FloatClass.NORMAL</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>251</th>\n",
       "      <td>-5.734400e+04</td>\n",
       "      <td>30</td>\n",
       "      <td>15</td>\n",
       "      <td>3</td>\n",
       "      <td>1.75</td>\n",
       "      <td>1</td>\n",
       "      <td>FloatClass.NORMAL</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>252</th>\n",
       "      <td>-inf</td>\n",
       "      <td>31</td>\n",
       "      <td>16</td>\n",
       "      <td>0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1</td>\n",
       "      <td>FloatClass.INFINITE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>253</th>\n",
       "      <td>NaN</td>\n",
       "      <td>31</td>\n",
       "      <td>16</td>\n",
       "      <td>1</td>\n",
       "      <td>1.25</td>\n",
       "      <td>1</td>\n",
       "      <td>FloatClass.NAN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>254</th>\n",
       "      <td>NaN</td>\n",
       "      <td>31</td>\n",
       "      <td>16</td>\n",
       "      <td>2</td>\n",
       "      <td>1.50</td>\n",
       "      <td>1</td>\n",
       "      <td>FloatClass.NAN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>255</th>\n",
       "      <td>NaN</td>\n",
       "      <td>31</td>\n",
       "      <td>16</td>\n",
       "      <td>3</td>\n",
       "      <td>1.75</td>\n",
       "      <td>1</td>\n",
       "      <td>FloatClass.NAN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>256 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              fval  exp  expval  significand  fsignificand  signbit  \\\n",
       "code                                                                  \n",
       "0     0.000000e+00    0     -14            0          0.00        0   \n",
       "1     1.525879e-05    0     -14            1          0.25        0   \n",
       "2     3.051758e-05    0     -14            2          0.50        0   \n",
       "3     4.577637e-05    0     -14            3          0.75        0   \n",
       "4     6.103516e-05    1     -14            0          1.00        0   \n",
       "...            ...  ...     ...          ...           ...      ...   \n",
       "251  -5.734400e+04   30      15            3          1.75        1   \n",
       "252           -inf   31      16            0          1.00        1   \n",
       "253            NaN   31      16            1          1.25        1   \n",
       "254            NaN   31      16            2          1.50        1   \n",
       "255            NaN   31      16            3          1.75        1   \n",
       "\n",
       "                    fclass  \n",
       "code                        \n",
       "0          FloatClass.ZERO  \n",
       "1     FloatClass.SUBNORMAL  \n",
       "2     FloatClass.SUBNORMAL  \n",
       "3     FloatClass.SUBNORMAL  \n",
       "4        FloatClass.NORMAL  \n",
       "...                    ...  \n",
       "251      FloatClass.NORMAL  \n",
       "252    FloatClass.INFINITE  \n",
       "253         FloatClass.NAN  \n",
       "254         FloatClass.NAN  \n",
       "255         FloatClass.NAN  \n",
       "\n",
       "[256 rows x 7 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fmt = format_info_ocp_e5m2\n",
    "vals = [decode_float(fmt, i) for i in range(256)]\n",
    "DataFrame(vals).set_index(\"code\")"
   ]
  },
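  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The columns above can be cross-checked by hand. The cell below is a minimal pure-Python sketch, assuming the usual E5M2 layout (1 sign bit, 5 exponent bits, 2 fraction bits, exponent bias 15); `decode_e5m2_by_hand` is a hypothetical helper written for illustration and is not part of `gfloat`. The values it returns for codes 1, 4, 251 and 252 should agree with the `fval` column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def decode_e5m2_by_hand(code):\n",
    "    # Hypothetical helper for illustration only; not part of gfloat\n",
    "    sign = -1.0 if code & 0x80 else 1.0  # top bit is the sign\n",
    "    exp = (code >> 2) & 0x1F  # next 5 bits: biased exponent\n",
    "    frac = code & 0x3  # low 2 bits: fraction\n",
    "    if exp == 0x1F:\n",
    "        # All-ones exponent: infinity if the fraction is zero, NaN otherwise\n",
    "        return sign * math.inf if frac == 0 else math.nan\n",
    "    if exp == 0:\n",
    "        # Zero and subnormals: no implicit leading 1, exponent fixed at 1 - bias\n",
    "        return sign * (frac / 4) * 2.0 ** (1 - 15)\n",
    "    # Normal numbers: implicit leading 1, exponent is exp - bias\n",
    "    return sign * (1 + frac / 4) * 2.0 ** (exp - 15)\n",
    "\n",
    "\n",
    "[decode_e5m2_by_hand(c) for c in (1, 4, 251, 252)]"
   ]
  },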
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Additional format info: special values, min, max, dynamic range\n",
    "\n",
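    "`FormatInfo` also exposes other characteristics of each format, such as the exponent bias, the number of NaNs, and the extreme values.\n",
    "The following cell reproduces parts of Tables 1 and 2 of the OCP spec, with a P3109 format added for comparison:"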
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Format                  ocp_e4m3             ocp_e5m2              p3109_p3\n",
      "Max exponent (emax)     8                    15                    15\n",
      "Exponent bias           7                    15                    16\n",
      "Infinities              0                    2                     2\n",
      "Number of NaNs          2                    6                     1\n",
      "Number of zeros         2                    2                     1\n",
      "Max normal number       448.0                57344.0               49152.0\n",
      "Min normal number       0.015625             6.103515625e-05       3.0517578125e-05\n",
      "Min subnormal number    0.001953125          1.52587890625e-05     7.62939453125e-06\n",
      "Dynamic range (binades) 18                   32                    33\n"
     ]
    }
   ],
   "source": [
    "def compute_dynamic_range(fi):\n",
    "    # Binades (powers of two) between the smallest positive and largest finite values\n",
    "    return np.log2(fi.max / fi.smallest)\n",
    "\n",
    "\n",
    "for prop, probe in (\n",
    "    (\"Format                 \", lambda fi: fi.name.replace(\"format_info_\", \"\")),\n",
    "    (\"Max exponent (emax)    \", lambda fi: fi.emax),\n",
    "    (\"Exponent bias          \", lambda fi: fi.expBias),\n",
    "    (\"Infinities             \", lambda fi: 2 * int(fi.has_infs)),\n",
    "    (\"Number of NaNs         \", lambda fi: fi.num_nans),\n",
    "    (\"Number of zeros        \", lambda fi: int(fi.has_zero) + int(fi.has_nz)),\n",
    "    (\"Max normal number      \", lambda fi: fi.max),\n",
    "    (\"Min normal number      \", lambda fi: fi.smallest_normal),\n",
    "    (\"Min subnormal number   \", lambda fi: fi.smallest_subnormal),\n",
    "    (\"Dynamic range (binades)\", lambda fi: round(compute_dynamic_range(fi))),\n",
    "):\n",
    "    print(\n",
    "        f\"{prop} {probe(format_info_ocp_e4m3):<20} {probe(format_info_ocp_e5m2):<20}  {probe(format_info_p3109(3))}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How do subnormals affect dynamic range?\n",
    "\n",
    "Most, if not all, low-precision formats include subnormal numbers, because they add values near zero and so increase the dynamic range.\n",
    "A natural question is \"by how much?\".  To answer this, we can create a hypothetical new format: a copy of `e4m3`, but with `has_subnormals` set to `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "e4m3_no_subnormals = copy.copy(format_info_ocp_e4m3)\n",
    "e4m3_no_subnormals.has_subnormals = False"
   ]
  },
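  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough hand check of what to expect, the cell below works out both dynamic ranges with plain arithmetic. It assumes that, with subnormals disabled, the exponent-0 codes are decoded with an implicit leading 1, so that the smallest positive value becomes `(1 + 1/8) * 2**-7`. That is an assumption about how `gfloat` interprets the flag, and the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "max_normal = 448.0  # largest finite E4M3 value, from the table above\n",
    "min_with = 2.0**-9  # smallest subnormal, 0.001953125\n",
    "min_without = (1 + 1 / 8) * 2.0**-7  # assumed smallest positive value without subnormals\n",
    "\n",
    "print(math.log2(max_normal / min_with))  # dynamic range with subnormals, in binades\n",
    "print(math.log2(max_normal / min_without))  # dynamic range without subnormals, in binades\n",
    "print(min_without / min_with)  # ratio of the two smallest positive values"
   ]
  },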
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
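    "Now compute the dynamic range, in binades, with and without subnormals. The final ratio is the factor by which subnormals shrink the smallest positive value:"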
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dynamic range with subnormals = 17.807354922057606\n",
      "Dynamic range without subnormals = 15.637429920615292\n",
      "Ratio = 4.5\n"
     ]
    }
   ],
   "source": [
    "dr_with = compute_dynamic_range(format_info_ocp_e4m3)\n",
    "dr_without = compute_dynamic_range(e4m3_no_subnormals)\n",
    "\n",
    "print(f\"Dynamic range with subnormals = {dr_with}\")\n",
    "print(f\"Dynamic range without subnormals = {dr_without}\")\n",
    "print(f\"Ratio = {2**(dr_with - dr_without):.1f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_dtypes",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
