In PyTorch code, the implementation could look as follows (bias omitted for simplicity):
import torch
class MyRNNCell(torch.nn.Module):
    """A minimal recurrent cell: h_t = tanh(x_t W_xh + h_{t-1} W_hh), y_t = h_t W_hy.

    Bias terms are omitted for simplicity.  The cell remembers its own
    hidden state between calls; pass ``h`` explicitly to override it.
    """

    def __init__(self, rnn_units, input_dim, output_dim):
        super().__init__()
        # Weight matrices (no bias).  NOTE(review): plain randn init has
        # unit variance, which easily saturates tanh for large rnn_units —
        # acceptable for a demo, use a scaled init in real code.
        self.W_xh = torch.nn.Parameter(torch.randn(input_dim, rnn_units))
        self.W_hh = torch.nn.Parameter(torch.randn(rnn_units, rnn_units))
        self.W_hy = torch.nn.Parameter(torch.randn(rnn_units, output_dim))
        # Register the hidden state as a buffer so it moves with the module
        # on .to(device)/.cuda() — a plain tensor attribute would be left
        # behind on the original device.
        self.register_buffer("h", torch.zeros(1, rnn_units))

    def forward(self, x, h=None):
        """Run one recurrence step.

        Args:
            x: input of shape (batch, input_dim).
            h: optional previous hidden state of shape (batch, rnn_units);
               defaults to the state stored from the previous call.

        Returns:
            (output, hidden): output has shape (batch, output_dim),
            hidden has shape (batch, rnn_units).
        """
        if h is None:
            h = self.h
        # Update the hidden state.
        new_h = torch.tanh(torch.matmul(x, self.W_xh) + torch.matmul(h, self.W_hh))
        self.h = new_h
        # Compute the output from the new hidden state.
        output = new_h.matmul(self.W_hy)
        return output, new_h
# Test the custom cell on a single random input.
rnn = MyRNNCell(10, 5, 3)
x = torch.randn(1, 5)  # (batch, input_dim)
output, h = rnn(x)
print(output)
# Example output (weights are random, values will differ per run):
# tensor([[2.0087, 4.8321, 0.6784]], grad_fn=<MmBackward>)
We can also use the built-in implementation:
# The equivalent built-in implementation.
rnn = torch.nn.RNN(input_size=5, hidden_size=10, num_layers=1, batch_first=True)
x = torch.randn(1, 1, 5)  # (batch_size, seq_len, input_size)
output, h = rnn(x)
print(output)
# Example output (weights are random, values will differ per run):
# tensor([[[ 0.2817,  0.2701, -0.0583,  0.0871, -0.3169, -0.1208, -0.1688,
#           -0.5907, -0.5431,  0.2129]]], grad_fn=<TransposeBackward1>)