Created August 9, 2023 18:26
Save sampathweb/295ea4c8b7fa964e8f58728f27e293ac to your computer and use it in GitHub Desktop.
GitHub-Issue-Keras-Core-684-load-issue.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPxuuIj6TpGgcQ0GbHOlRTN", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/sampathweb/295ea4c8b7fa964e8f58728f27e293ac/github-issue-keras-core-684-load-issue.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "zRuBJxNmh5M-", | |
"outputId": "a033151a-2ee2-438d-df79-d27e53f3bf6d" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/880.1 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/880.1 kB\u001b[0m \u001b[31m937.8 kB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m337.9/880.1 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━\u001b[0m \u001b[32m798.7/880.1 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.1/880.1 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
], | |
"source": [ | |
"# NOTE(review): pin the keras-core version this issue was reproduced with (keras-core==X.Y.Z)\n", | |
"%pip install -q keras-core" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"\n", | |
"# The Keras backend is selected when keras_core is first imported,\n", | |
"# so the environment variable has to be set before that import.\n", | |
"os.environ[\"KERAS_BACKEND\"] = \"torch\"\n", | |
"\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import keras_core as keras" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "bxc8M84Jh8tK", | |
"outputId": "5862491b-04ab-4b5b-edb8-cc69831a45b3" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Using PyTorch backend.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@keras.saving.register_keras_serializable()\n", | |
"class CustomModel(keras.Model):\n", | |
"    \"\"\"Keras model whose `fit()` uses a hand-written PyTorch train step.\n", | |
"\n", | |
"    The decorator registers this class when the cell runs. A saved `.keras`\n", | |
"    file does not contain this code, so in a fresh runtime this cell must be\n", | |
"    run before `keras.models.load_model` is called.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def train_step(self, data):\n", | |
"        \"\"\"Run one optimization step on a batch and return metric results.\n", | |
"\n", | |
"        Args:\n", | |
"            data: the batch, unpacked as `(x, y)` (no sample weights).\n", | |
"\n", | |
"        Returns:\n", | |
"            Dict mapping each metric name to its current result.\n", | |
"        \"\"\"\n", | |
"        x, y = data\n", | |
"        # Clear gradients left over from the previous step.\n", | |
"        self.zero_grad()\n", | |
"        y_pred = self(x, training=True)  # Forward pass\n", | |
"        loss = self.compute_loss(y=y, y_pred=y_pred)\n", | |
"        loss.backward()\n", | |
"\n", | |
"        trainable_weights = list(self.trainable_weights)\n", | |
"        # Under the torch backend `v.value` is the underlying torch tensor,\n", | |
"        # whose `.grad` was just populated by `loss.backward()`.\n", | |
"        gradients = [v.value.grad for v in trainable_weights]\n", | |
"\n", | |
"        # Apply the update outside autograd so it is not recorded in the graph.\n", | |
"        with torch.no_grad():\n", | |
"            self.optimizer.apply(gradients, trainable_weights)\n", | |
"\n", | |
"        # The \"loss\" tracker takes the scalar loss; other metrics compare y and y_pred.\n", | |
"        for metric in self.metrics:\n", | |
"            if metric.name == \"loss\":\n", | |
"                metric.update_state(loss)\n", | |
"            else:\n", | |
"                metric.update_state(y, y_pred)\n", | |
"        return {m.name: m.result() for m in self.metrics}" | |
], | |
"metadata": { | |
"id": "QRXDjhtOiGUr" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"inputs = keras.Input(shape=(32,))\n", | |
"outputs = keras.layers.Dense(1)(inputs)\n", | |
"model = CustomModel(inputs, outputs)\n", | |
"model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n", | |
"\n", | |
"# Seeded generator so the synthetic data is identical on every run.\n", | |
"rng = np.random.default_rng(42)\n", | |
"x = rng.random((1000, 32))\n", | |
"y = rng.random((1000, 1))\n", | |
"model.fit(x, y, epochs=3)\n", | |
"\n", | |
"# The archive stores the model config and weights, not CustomModel's code:\n", | |
"# any session that loads it must define CustomModel first.\n", | |
"model.save(\"final_model.keras\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "szmioBZAiOVy", | |
"outputId": "7953c1ad-94ce-4ff7-c16b-024eed22c066" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1/3\n", | |
"\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - mean_absolute_error: 1.1603 - loss: 1.6351\n", | |
"Epoch 2/3\n", | |
"\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - mean_absolute_error: 0.7104 - loss: 0.7184\n", | |
"Epoch 3/3\n", | |
"\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - mean_absolute_error: 0.5161 - loss: 0.3951\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# In a fresh runtime, run the CustomModel cell first: the saved file refers\n", | |
"# to the class by name but does not contain its code.\n", | |
"final_model = keras.models.load_model(\"final_model.keras\")\n", | |
"# Confirm the custom subclass (with its train_step) was restored.\n", | |
"assert isinstance(final_model, CustomModel), type(final_model)\n", | |
"final_model.summary()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 193 | |
}, | |
"id": "bpC3zPO4iVji", | |
"outputId": "780c45bd-2be4-46a4-fb03-6e87297ab74d" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mModel: \"custom_model_2\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"custom_model_2\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", | |
"│ input_layer_3 (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
"├─────────────────────────────────┼───────────────────────────┼────────────┤\n", | |
"│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m33\u001b[0m │\n", | |
"└─────────────────────────────────┴───────────────────────────┴────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", | |
"│ input_layer_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
"├─────────────────────────────────┼───────────────────────────┼────────────┤\n", | |
"│ dense_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">33</span> │\n", | |
"└─────────────────────────────────┴───────────────────────────┴────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m33\u001b[0m (1.03 KB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">33</span> (1.03 KB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m33\u001b[0m (1.03 KB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">33</span> (1.03 KB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "b5iOkxmGibmS" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hello, it works fine when the model is loaded in the same session that created it. Try loading it in a different Colab session instead: save the model, upload the file, and load it from scratch. Thanks for your time!