{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# render matplotlib figures inside the notebook (default in modern Jupyter; kept for older kernels)\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
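        "\n# Adding numbers with an attention LSTM\nThis example illustrates how to use a 'self-attention' mechanism on top of an LSTM to make\nthe predictions of the LSTM interpretable. The model receives two input sequences, each containing\na delimiter (a 0); the target is the sum of all values that come before the delimiter in both\nsequences, so a well-trained attention layer should focus on the delimiter positions.\n"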
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import ai4water\n\n# record the ai4water version this example was run with\nai4water.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n\n# record the TensorFlow version this example was run with\ntf.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\n\nfrom ai4water import Model\nfrom easy_mpl import imshow\n\n# record the NumPy version this example was run with\nnp.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# problem size: every example has `seq_len` steps and `num_inputs` features\nseq_len = 20\nnum_inputs = 2\n\n# Input -> AttentionLSTM (LSTM with self-attention) -> Dense layer with a single output unit\nmodel = Model(\n    model={\n        \"layers\": {\n            \"Input_1\": {\"shape\": (seq_len, num_inputs)},\n            \"AttentionLSTM\": {\"num_inputs\": num_inputs, \"lstm_units\": 16},\n            \"Dense\": 1,\n        }\n    },\n)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# inspect the model's input(s)\nmodel.inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# inspect the model's output(s)\nmodel.outputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import Optional, Tuple\n\n\ndef add_numbers_before_delimiter(n: int,\n                                 seq_length: int,\n                                 index_1: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    Task: Add all the numbers that come before the delimiter.\n    x = [1, 2, 3, 0, 4, 5, 6, 7, 8, 9]. Result is y = 6.\n    (Illustrative only: real values are drawn uniformly from [0, 1) and the delimiter is 0.0.)\n\n    @param n: number of samples in (x, y).\n    @param seq_length: length of the sequence of x.\n    @param index_1: position of the delimiter (the 0.0) in every sequence; negative values\n        count from the end, as in Python indexing. If None, a position in\n        [1, seq_length - 1] is drawn independently for each sample, so at least one\n        number always precedes the delimiter.\n    @return: two numpy arrays x and y of shape (n, seq_length, 1) and (n, 1).\n    @raise ValueError: if index_1 is None and seq_length < 2.\n    @raise IndexError: if index_1 is outside [-seq_length, seq_length - 1].\n    \"\"\"\n    if index_1 is None:\n        if seq_length < 2:\n            raise ValueError(\n                \"seq_length must be at least 2 to place a delimiter after a number\")\n    elif not -seq_length <= index_1 < seq_length:\n        raise IndexError(\n            f\"index_1={index_1} is out of range for seq_length={seq_length}\")\n\n    x = np.random.uniform(0, 1, (n, seq_length))\n    if index_1 is None:\n        delimiters = np.random.randint(1, seq_length, size=n)\n    else:\n        # normalise negative positions so the mask below matches Python slicing\n        delimiters = np.full(n, index_1 % seq_length)\n\n    # True for every step strictly before that row's delimiter\n    before = np.arange(seq_length)[np.newaxis, :] < delimiters[:, np.newaxis]\n    y = np.where(before, x, 0.0).sum(axis=1, keepdims=True)\n    x[np.arange(n), delimiters] = 0.0\n\n    return np.expand_dims(x, axis=-1), y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Two independent input streams, one per model input feature. The target is the sum of\n# the numbers before the delimiter in *both* streams.\nnp.random.seed(313)  # fixed seed so the synthetic data set is reproducible\n\nx_train1, y_train1 = add_numbers_before_delimiter(2_000, seq_len)\nx_train2, y_train2 = add_numbers_before_delimiter(2_000, seq_len)\nx_train = np.concatenate([x_train1, x_train2], axis=2)\ny_train = y_train1 + y_train2\n\nx_val1, y_val1 = add_numbers_before_delimiter(500, seq_len)\nx_val2, y_val2 = add_numbers_before_delimiter(500, seq_len)\nx_val = np.concatenate([x_val1, x_val2], axis=2)\ny_val = y_val1 + y_val2\n\n# the Input layer expects (seq_len, num_inputs) per example\nassert x_train.shape == (2_000, seq_len, num_inputs), x_train.shape\nassert x_val.shape == (500, seq_len, num_inputs), x_val.shape\n\n# only a cell's last expression is displayed, so show both splits together\n(x_train.shape, y_train.shape), (x_val.shape, y_val.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Train on the training split, passing (x_val, y_val) as validation data.\nh = model.fit(\n    x=x_train,\n    y=y_train,\n    validation_data=(x_val, y_val),\n    epochs=1000,\n    verbose=1,\n)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Now prepare test data: a fresh draw of the same two-stream task\nx_test1, y_test1 = add_numbers_before_delimiter(500, seq_len)\nx_test2, y_test2 = add_numbers_before_delimiter(500, seq_len)\n\n# stack the two streams as features on the last axis; the target sums both streams\nx_test = np.concatenate((x_test1, x_test2), axis=2)\ny_test = y_test1 + y_test2\n\nx_test.shape, y_test.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# attention weights for the test inputs; list the available keys\nattention_weights = model.get_attention_lstm_weights(x_test)\nattention_weights.keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def plot_attention_vs_truth(weights, x_stream, num_examples=10, title_suffix=\"\"):\n    \"\"\"Plot learned attention weights above the true delimiter positions.\n\n    Top panel: the first `num_examples` rows of `weights`. Bottom panel: 1 where the\n    matching input stream equals the 0.0 delimiter, else 0.\n\n    weights: attention weights for one input stream; assumed (examples, seq_len) -\n        TODO confirm against `Model.get_attention_lstm_weights`.\n    x_stream: the single-feature input for that stream, shape (n, seq_len, 1).\n    num_examples: number of leading examples (rows) to show.\n    title_suffix: appended to the attention panel's title to tell streams apart.\n    Returns the matplotlib Figure.\n    \"\"\"\n    fig, axis = plt.subplots(2, sharex=\"all\")\n\n    imshow(weights[0:num_examples],\n           ylabel=\"Examples\",\n           title=\"Important steps from Attention\" + title_suffix, cmap=\"hot\",\n           ax=axis[0], show=False)\n\n    # derive the number of steps from the data instead of hard-coding it\n    steps = x_stream.shape[1]\n    truth = x_stream[0:num_examples].reshape(-1, steps)\n    truth = np.where(truth == 0.0, 1.0, 0.0)\n    imshow(truth, ylabel=\"Examples\",\n           xlabel=\"Sequence Length (lookback steps)\",\n           xticklabels=np.arange(steps),\n           title=\"Actual important steps\", cmap=\"hot\",\n           ax=axis[1])\n    return fig\n\n\nnum_examples = 10  # number of examples to show\n\n# trailing ';' suppresses the cell's text output (the returned Figure's repr)\nplot_attention_vs_truth(attention_weights[\"self_attention\"], x_test1,\n                        num_examples, title_suffix=\" (input 1)\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "plot_attention_vs_truth(attention_weights[\"self_attention_1\"], x_test2,\n                        num_examples, title_suffix=\" (input 2)\");"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}