{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline, below the cell that creates them.\n",
        "%matplotlib inline"
      ]
    },
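    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# ARG prediction\n\nPredict antibiotic resistance gene (ARG) levels in the `busan_beach` dataset with an attention LSTM, then visualize the learned attention maps.\n"
      ]
    },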
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "from easy_mpl import imshow\n",
        "\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "\n",
        "from ai4water import Model\n",
        "from ai4water.datasets import busan_beach\n",
        "\n",
        "from SeqMetrics import RegressionMetrics\n",
        "\n",
        "# Environmental drivers used as model inputs. Later cells treat the last\n",
        "# column of `data` as the target, so busan_beach is expected to append it.\n",
        "input_columns = [\n",
        "    'tide_cm', 'wat_temp_c', 'air_temp_c', 'sal_psu',\n",
        "    'pcp_mm', 'wind_dir_deg', 'wind_speed_mps',\n",
        "]\n",
        "data = busan_beach(inputs=input_columns)\n",
        "\n",
        "print(data.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Input features: every column of `data` except the last one.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Every column except the last one is a model input.\n",
        "input_features = list(data.columns[:-1])\n",
        "print(input_features)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# The last column is the prediction target (kept as a one-element list).\n",
        "output_features = list(data.columns[-1:])\n",
        "print(output_features)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# seq_len: lookback window length used by the model (ts_args lookback);\n",
        "# num_inputs: number of input features per time step.\n",
        "seq_len, num_inputs = 14, len(input_features)"
      ]
    },
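    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build the model: an `AttentionLSTM` layer over a lookback window of `seq_len` steps, followed by a `Dense` output layer.\n\n"
      ]
    },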
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Layer specification: input window -> attention LSTM -> dense output.\n",
        "layers_config = {\n",
        "    \"Input_1\": {\"shape\": (seq_len, num_inputs)},\n",
        "    \"AttentionLSTM\": {\"num_inputs\": num_inputs, \"lstm_units\": 10},\n",
        "    \"Dense\": 1,\n",
        "}\n",
        "\n",
        "model = Model(\n",
        "    model={\"layers\": layers_config},\n",
        "    x_transformation='minmax',\n",
        "    y_transformation=\"log\",\n",
        "    input_features=input_features,\n",
        "    output_features=output_features,\n",
        "    train_fraction=1.0,\n",
        "    split_random=True,\n",
        "    ts_args={\"lookback\": seq_len},\n",
        "    lr=0.005,\n",
        "    batch_size=24,\n",
        "    # NOTE(review): very high epoch cap; presumably `patience` ends training much earlier.\n",
        "    epochs=50000,\n",
        "    patience=1000,\n",
        "    monitor=[\"nse\"],\n",
        ")"
      ]
    },
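    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Train the model. This can take a long time: `epochs` is capped at 50000, with early stopping governed by `patience`.\n\n"
      ]
    },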
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Train on `data`; `h` keeps whatever fit() returns\n",
        "# (presumably the Keras training History).\n",
        "h = model.fit(data=data, verbose=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Validation split as (inputs, targets); x_val is reused below for the\n",
        "# attention maps.\n",
        "x_val, y_val = model.validation_data(data=data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check performance on the validation data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Take the observed values from the same predict call instead of reusing\n",
        "# y_val: both arrays then pass through the same (inverse) y_transformation\n",
        "# and are compared on one scale. y_val from validation_data() may still be\n",
        "# in the transformed (log) space -- TODO confirm against ai4water.\n",
        "true_val, pred_val = model.predict_on_validation_data(\n",
        "    data=data, process_results=False, return_true=True)\n",
        "metrics = RegressionMetrics(true_val, pred_val)\n",
        "metrics.nse(), metrics.r2()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Attention weights of the AttentionLSTM layer for the validation inputs;\n",
        "# presumably a mapping with one entry per input feature (plotted below).\n",
        "attention_weights = model.get_attention_lstm_weights(x_val)"
      ]
    },
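    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot attention maps for the first validation examples (each example min-max normalized), above the corresponding input values.\n\n"
      ]
    },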
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "num_examples = 40  # number of examples to show\n",
        "\n",
        "# NOTE(review): assumes the i-th attention entry belongs to input_features[i]\n",
        "# and x_val[:, :, i]; confirm the key order of get_attention_lstm_weights.\n",
        "for feature_idx, (_key, weights) in enumerate(attention_weights.items()):\n",
        "    feature = input_features[feature_idx]\n",
        "    fig, (ax_attn, ax_input) = plt.subplots(2, sharex=\"all\", figsize=(6, 8))\n",
        "\n",
        "    # Transpose so examples become columns; MinMaxScaler then rescales each\n",
        "    # example to [0, 1] independently of the others.\n",
        "    attn_map = MinMaxScaler().fit_transform(weights[0:num_examples].T)\n",
        "    imshow(attn_map, colorbar=True, ax=ax_attn, show=False, cmap=\"hot\",\n",
        "           title=f\"Attention map for {feature}\")\n",
        "\n",
        "    # Input values of the same examples, laid out the same way.\n",
        "    inputs_map = x_val[:num_examples, :, feature_idx].T\n",
        "    imshow(inputs_map, colorbar=True, cmap=\"hot\", ax=ax_input, show=False,\n",
        "           title=feature)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
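    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training data\n\nRepeat the evaluation on the training data. The attention maps in this section are computed on all samples returned by `model.all_data`.\n\n"
      ]
    },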
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Training split as (inputs, targets).\n",
        "x_train, y_train = model.training_data(data=data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check performance on the training data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Take the observed values from the same predict call instead of reusing\n",
        "# y_train: both arrays then pass through the same (inverse) y_transformation\n",
        "# and are compared on one scale. y_train from training_data() may still be\n",
        "# in the transformed (log) space -- TODO confirm against ai4water.\n",
        "true_train, pred_train = model.predict_on_training_data(\n",
        "    data=data, process_results=False, return_true=True)\n",
        "metrics = RegressionMetrics(true_train, pred_train)\n",
        "metrics.nse(), metrics.r2()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fetch every sample rather than a single split: allow_nan_labels=2\n",
        "# presumably keeps samples whose labels are missing, and split_random=False\n",
        "# presumably keeps them in their original (time) order. The overrides apply\n",
        "# only inside this cell; the previous data_config values are restored\n",
        "# afterwards, so re-running the training/validation cells above still\n",
        "# yields the same splits (no hidden state left on `model`).\n",
        "data_overrides = {'allow_nan_labels': 2, 'split_random': False}\n",
        "_unset = object()\n",
        "saved_data_config = {key: model.data_config.get(key, _unset)\n",
        "                     for key in data_overrides}\n",
        "model.data_config.update(data_overrides)\n",
        "try:\n",
        "    x, y = model.all_data(data=data)\n",
        "    attention_weights_tr = model.get_attention_lstm_weights(x)\n",
        "finally:\n",
        "    for key, value in saved_data_config.items():\n",
        "        if value is _unset:\n",
        "            model.data_config.pop(key, None)\n",
        "        else:\n",
        "            model.data_config[key] = value"
      ]
    },
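    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot attention maps (each example min-max normalized) for the first samples of the full dataset, above the corresponding input values.\n\n"
      ]
    },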
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "num_examples = 1400  # number of examples to show\n",
        "\n",
        "# NOTE(review): assumes the i-th attention entry belongs to input_features[i]\n",
        "# and x[:, :, i]; confirm the key order of get_attention_lstm_weights.\n",
        "for feature_idx, (_key, weights) in enumerate(attention_weights_tr.items()):\n",
        "    feature = input_features[feature_idx]\n",
        "    fig, (ax_attn, ax_input) = plt.subplots(\n",
        "        2, sharex=\"all\", figsize=(8, 6), gridspec_kw={\"hspace\": 0.1})\n",
        "\n",
        "    # Transpose so examples become columns, then rescale each example to [0, 1].\n",
        "    attn_map = MinMaxScaler().fit_transform(weights[0:num_examples].T)\n",
        "    imshow(attn_map, colorbar=True, ax=ax_attn, show=False, cmap=\"hot\",\n",
        "           title=f\"Attention map for {feature}\", aspect=\"auto\")\n",
        "\n",
        "    inputs_map = x[:num_examples, :, feature_idx].T\n",
        "    imshow(inputs_map, colorbar=True, cmap=\"hot\", ax=ax_input, show=False,\n",
        "           title=feature, aspect=\"auto\")\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the same attention weights without normalization.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Same layout as the previous cell (and its num_examples), but the attention\n",
        "# weights are shown as-is, without per-example rescaling.\n",
        "for feature_idx, (_key, weights) in enumerate(attention_weights_tr.items()):\n",
        "    feature = input_features[feature_idx]\n",
        "    fig, (ax_attn, ax_input) = plt.subplots(\n",
        "        2, sharex=\"all\", figsize=(8, 6), gridspec_kw={\"hspace\": 0.1})\n",
        "\n",
        "    imshow(weights[0:num_examples].T, colorbar=True, ax=ax_attn, show=False,\n",
        "           cmap=\"hot\", title=f\"Attention map for {feature}\", aspect=\"auto\")\n",
        "    imshow(x[:num_examples, :, feature_idx].T, colorbar=True, cmap=\"hot\",\n",
        "           ax=ax_input, show=False, title=feature, aspect=\"auto\")\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}