In PyTorch code, the implementation could look as follows (bias omitted for simplicity):

import torch
 
class MyRNNCell(torch.nn.Module):
    """A minimal vanilla RNN cell (bias omitted for simplicity).

    One step of the recurrence:
        h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh)
        y_t = h_t @ W_hy

    Args:
        rnn_units: dimensionality of the hidden state.
        input_dim: dimensionality of each input vector x_t.
        output_dim: dimensionality of each output vector y_t.
    """

    def __init__(self, rnn_units, input_dim, output_dim):
        super().__init__()

        # initialize the weights
        # NOTE(review): plain randn init tends to saturate tanh for larger
        # dims; kept as-is to preserve the tutorial's behavior.
        self.W_xh = torch.nn.Parameter(torch.randn(input_dim, rnn_units))
        self.W_hh = torch.nn.Parameter(torch.randn(rnn_units, rnn_units))
        self.W_hy = torch.nn.Parameter(torch.randn(rnn_units, output_dim))

        # Register the hidden state as a (non-trainable) buffer so it moves
        # with the module on .to(device) and shows up in state_dict().
        # Assigning self.h in forward() updates this buffer in place.
        self.register_buffer("h", torch.zeros(1, rnn_units))

    def forward(self, x, h=None):
        """Run one time step.

        Args:
            x: input of shape (1, input_dim).
            h: optional previous hidden state of shape (1, rnn_units);
               defaults to the hidden state stored on the module.

        Returns:
            (output, hidden_state) — shapes (1, output_dim) and (1, rnn_units).
        """
        if h is None:
            h = self.h

        # update the hidden state
        self.h = torch.tanh(torch.matmul(x, self.W_xh) + torch.matmul(h, self.W_hh))

        # compute the output
        output = self.h.matmul(self.W_hy)

        # return the current output and hidden state
        return output, self.h
 
# Smoke test: run one forward step of the custom cell on random input.
rnn = MyRNNCell(10, 5, 3)  # 10 hidden units, 5-dim input, 3-dim output
x = torch.randn(1, 5)  # a single random input vector (batch of 1)
output, h = rnn(x)  # one time step -> (1, 3) output and (1, 10) hidden state
print(output)

Output: tensor([[2.0087, 4.8321, 0.6784]], grad_fn=<MmBackward>)

We can also use PyTorch's built-in implementation, torch.nn.RNN. Note that unlike our single-step cell, it processes a whole sequence at once, so it takes a 3-D input of shape (batch_size, seq_len, input_size):

# Same single step, this time with PyTorch's built-in RNN layer.
rnn = torch.nn.RNN(
    input_size=5,
    hidden_size=10,
    num_layers=1,
    batch_first=True,
)
# Built-in RNN expects a 3-D input: (batch_size, seq_len, input_size).
x = torch.randn(1, 1, 5)
output, h = rnn(x)
print(output)

Output: tensor([[[ 0.2817, 0.2701, -0.0583, 0.0871, -0.3169, -0.1208, -0.1688, -0.5907, -0.5431, 0.2129]]], grad_fn=<TransposeBackward1>)


🌱 Back to Garden