手写单层RNN网络,后续更新

news/2025/2/3 4:27:32 标签: rnn, 人工智能, 深度学习

文章目录

  • 1. 原理
  • 2. pytorch 源码,只是测试版,后续持续优化

1. 原理

根据如下公式,简单的手写实现单层的RNN神经网络,加强代码功能和对网络的理解能力
在这里插入图片描述

2. pytorch 源码,只是测试版,后续持续优化

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(23435)

if __name__ == "__main__":
    run_code = 0
    input_size = 4
    hidden_size = 3
    num_layers = 1
    batch_first = True
    single_rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first)
    print(single_rnn)
    for name in single_rnn.named_parameters():
        print(name)

    single_rnn_weight_ih_l0 = single_rnn.weight_ih_l0
    single_rnn_weight_hh_l0 = single_rnn.weight_hh_l0
    single_rnn_bias_ih_l0 = single_rnn.bias_ih_l0
    single_rnn_bias_hh_l0 = single_rnn.bias_hh_l0

    # print(f"single_rnn_weight_ih_l0=\n{single_rnn_weight_ih_l0}")
    # input --> batch_size,seq_len,feature_map
    in_batch_size = 1
    in_seq_len = 2
    in_feature_map = input_size
    input_matrix = torch.randn(in_batch_size, in_seq_len, in_feature_map)
    output_matrix, output_hn = single_rnn(input_matrix)
    print(f"output_matrix=\n{output_matrix}")
    print(f"output_hn=\n{output_hn}")
    test_output0 = input_matrix @ single_rnn_weight_ih_l0.T + single_rnn_bias_ih_l0
    ht_1 = torch.zeros_like(test_output0)
    print(f"ht_1=\n{ht_1}")
    print(f"ht_1.shape=\n{ht_1.shape}")
    test_output1 = ht_1 @ single_rnn_weight_hh_l0.T + single_rnn_bias_hh_l0
    test_output = torch.tanh(test_output1 + test_output0)
    ht_1[:,1, :] = test_output[:,0, :]
    test_output1 = ht_1 @ single_rnn_weight_hh_l0.T + single_rnn_bias_hh_l0
    test_output = torch.tanh(test_output1 + test_output0)
    print(f"test_output=\n{test_output}")
    print(f"test_output.shape=\n{test_output.shape}")
  • 结果:经计算,通过pytorch官方的API输出的结果和自定义的结果一致!!!
RNN(4, 3, batch_first=True)
('weight_ih_l0', Parameter containing:
tensor([[ 0.413,  0.044,  0.243,  0.171],
        [-0.093,  0.250, -0.499, -0.450],
        [-0.571,  0.220,  0.464, -0.154]], requires_grad=True))
('weight_hh_l0', Parameter containing:
tensor([[-0.403,  0.165, -0.244],
        [ 0.216, -0.511, -0.441],
        [ 0.133,  0.278, -0.211]], requires_grad=True))
('bias_ih_l0', Parameter containing:
tensor([ 0.115, -0.493,  0.555], requires_grad=True))
('bias_hh_l0', Parameter containing:
tensor([-0.309, -0.504,  0.311], requires_grad=True))
output_matrix=
tensor([[[ 0.243, -0.467, -0.554],
         [-0.013, -0.802, -0.490]]], grad_fn=<TransposeBackward1>)
output_hn=
tensor([[[-0.013, -0.802, -0.490]]], grad_fn=<StackBackward0>)
ht_1=
tensor([[[0., 0., 0.],
         [0., 0., 0.]]])
ht_1.shape=
torch.Size([1, 2, 3])
test_output=
tensor([[[ 0.243, -0.467, -0.554],
         [-0.013, -0.802, -0.490]]], grad_fn=<TanhBackward0>)
test_output.shape=
torch.Size([1, 2, 3])

http://www.niftyadmin.cn/n/5840478.html

相关文章

DeepSeek的提示词使用说明

一、DeepSeek概述 DeepSeek是一款基于先进推理技术的大型语言模型&#xff0c;能够根据用户提供的简洁提示词生成高质量、精准的内容。在实际应用中&#xff0c;DeepSeek不仅能够帮助用户完成各类文案撰写、报告分析、市场研究等工作&#xff0c;还能够生成结构化的内容&#…

动手学强化学习(四)——蒙特卡洛方法

一、蒙特卡洛方法 蒙特卡洛方法是一种无模型&#xff08;Model-Free&#xff09;的强化学习算法&#xff0c;它通过直接与环境交互采样轨迹&#xff08;episodes&#xff09;来估计状态或动作的价值函数&#xff08;Value Function&#xff09;&#xff0c;而不需要依赖环境动态…

处理 .gitignore 未忽略文件夹问题

本地删除缓存 例如 .idea 文件夹被其他同事误提交&#xff0c;那么他本地执行以下代码 git rm -r --cached .idea对应本地再提交即可

STM32 LED呼吸灯

接线图&#xff1a; 这里将正极接到PA0引脚上&#xff0c;负极接到GND&#xff0c;这样就高电平点亮LED&#xff0c;低电平熄灭。 占空比越大&#xff0c;LED越亮&#xff0c;占空比越小&#xff0c;LED越暗 PWM初始化配置 输出比较函数介绍&#xff1a; 用这四个函数配置输…

牛客周赛 Round 78

题目目录 A-时间表查询&#xff01;解题思路参考代码 B-一起做很甜的梦&#xff01;解题思路参考代码 C-翻之解题思路参考代码 D-乘之解题思路参考代码 E-在树上游玩解题思路参考代码 A-时间表查询&#xff01; \hspace{15pt} 今天是2025年1月25日&#xff0c;今年的六场牛客寒…

家庭财务管理系统的设计与实现

标题:家庭财务管理系统的设计与实现 内容:1.摘要 摘要&#xff1a;随着家庭经济的日益复杂&#xff0c;家庭财务管理变得越来越重要。本文旨在设计并实现一个功能强大的家庭财务管理系统&#xff0c;以帮助用户更好地管理家庭财务。通过对家庭财务管理需求的分析&#xff0c;我…

MySQL 如何深度分页问题

在实际的数据库应用场景中&#xff0c;我们常常会遇到需要进行分页查询的需求。对于少量数据的分页查询&#xff0c;MySQL 可以轻松应对。然而&#xff0c;当我们需要进行深度分页&#xff08;即从大量数据的中间位置开始获取少量数据&#xff09;时&#xff0c;就会面临性能严…

【Uniapp-Vue3】获取用户状态栏高度和胶囊按钮高度

在项目目录下创建一个utils文件&#xff0c;并在里面创建一个system.js文件。 在system.js中配置如下代码&#xff1a; const SYSTEM_INFO uni.getSystemInfoAsync();// 返回状态栏高度 export const getStatusBarHeight ()> SYSTEM_INFO.statusBarHeight || 15;// 返回胶…