import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# Create an 8x3x8 neural network
model = Sequential([
Dense(8, activation='relu', input_dim=8), # Input layer with 8 neurons
Dense(3, activation='relu'), # Hidden layer with 3 neurons
Dense(8, activation='sigmoid') # Output layer with 8 neurons
])
# Display the model summary
model.summary()
Loading...
# 展示神经网络的形状结构
%pip install pydot graphviz
# Install the graphviz system package
import sys
import os
# Install graphviz based on the operating system
if sys.platform.startswith('darwin'): # macOS
!brew install graphviz || echo "Homebrew not available or graphviz already installed"
elif sys.platform.startswith('linux'):
!apt-get update && apt-get install -y graphviz || echo "apt not available or graphviz already installed"
# Alternative for other Linux distributions
!yum install -y graphviz || echo "yum not available or graphviz already installed"
from IPython.display import SVG
from tensorflow.keras.utils import model_to_dot
dot = model_to_dot(model, show_shapes=True)
dot.set('dpi', '60')
print(dot)
SVG(dot.create(prog='dot', format='svg'))
Loading...
# Compile the model
model.compile(optimizer='adam',
loss='mse',
metrics=['accuracy'])
import numpy as np
# 创建训练数据 - 每个输入是一个8维的one-hot向量
X_train = np.array([
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1]
])
# 输出数据与输入相同(恒等函数)
y_train = X_train.copy()
# 训练模型
history = model.fit(X_train, y_train, epochs=5000, verbose=0)
# 评估模型性能
loss, accuracy = model.evaluate(X_train, y_train, verbose=0)
print(f"最终损失:{loss:.4f}, 准确率:{accuracy:.4f}")
# 测试模型
predictions = model.predict(X_train)
print("\n预测结果:")
for i in range(len(X_train)):
print(f"输入:{X_train[i]} -> 预测输出:{predictions[i]},期望: {y_train[i]}")
print("-" * 50)
最终损失:0.0009, 准确率:1.0000
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
预测结果:
输入:[1 0 0 0 0 0 0 0] -> 预测输出:[9.8932779e-01 3.8438298e-13 2.0175681e-11 2.1421800e-18 2.4873743e-02
2.1679604e-18 5.1605110e-03 3.8065081e-03],期望: [1 0 0 0 0 0 0 0]
--------------------------------------------------
输入:[0 1 0 0 0 0 0 0] -> 预测输出:[2.89736971e-14 9.89990175e-01 2.11480403e-23 9.79433209e-14
1.01906825e-02 1.35070120e-03 3.39770681e-26 5.43721905e-03],期望: [0 1 0 0 0 0 0 0]
--------------------------------------------------
输入:[0 0 1 0 0 0 0 0] -> 预测输出:[4.0145537e-13 3.9973972e-21 9.8758566e-01 4.7793160e-03 2.7272351e-02
4.7150066e-12 2.4005007e-03 1.0467570e-21],期望: [0 0 1 0 0 0 0 0]
--------------------------------------------------
输入:[0 0 0 1 0 0 0 0] -> 预测输出:[5.5531644e-21 2.1546344e-13 3.3683470e-03 9.8619699e-01 3.0464206e-02
2.6401365e-03 4.3417571e-14 7.2851368e-22],期望: [0 0 0 1 0 0 0 0]
--------------------------------------------------
输入:[0 0 0 0 1 0 0 0] -> 预测输出:[0.03348933 0.02771423 0.04098021 0.0336943 0.7992507 0.04338678
0.02854644 0.03776949],期望: [0 0 0 0 1 0 0 0]
--------------------------------------------------
输入:[0 0 0 0 0 1 0 0] -> 预测输出:[4.4649845e-22 3.1539593e-03 1.0899286e-13 8.4976470e-03 2.2211868e-02
9.8726237e-01 1.8719188e-24 5.6432147e-14],期望: [0 0 0 0 0 1 0 0]
--------------------------------------------------
输入:[0 0 0 0 0 0 1 0] -> 预测输出:[4.7646235e-03 2.3268191e-21 2.4147548e-03 8.1610424e-13 1.8457059e-02
1.3072618e-19 9.8956048e-01 3.4051971e-13],期望: [0 0 0 0 0 0 1 0]
--------------------------------------------------
输入:[0 0 0 0 0 0 0 1] -> 预测输出:[2.9792413e-03 4.3300455e-03 5.9100810e-18 7.5542978e-17 3.7516832e-02
2.5702881e-10 2.8379615e-13 9.8470265e-01],期望: [0 0 0 0 0 0 0 1]
--------------------------------------------------