I'm trying to implement a regularization term for the loss function of a neural network.
import torch
from torch import nn

# Dummy inputs: a batch of 32 graphs, each with 9 nodes,
# 5 atom-feature channels and 4 bond-feature channels.
reg_sig = torch.randn([32, 9, 5])  # node (atom) logits
reg_adj = torch.randn([32, 9, 9, 4])  # edge (bond) logits
Maug = reg_adj.shape[0]  # batch size
n_node = 9
n_bond_features = 4
n_atom_features = 5
SM_f = nn.Softmax(dim=2)  # softmax over atom types
SM_W = nn.Softmax(dim=3)  # softmax over bond types
p_f = SM_f(reg_sig)  # (32, 9, 5) atom-type probabilities
p_W = SM_W(reg_adj)  # (32, 9, 9, 4) bond-type probabilities
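As a minor aside, I assume the two Softmax modules could be replaced by direct functional calls (purely a style question; the result should be identical):

p_f = torch.softmax(reg_sig, dim=2)
p_W = torch.softmax(reg_adj, dim=3)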
Sig = nn.Sigmoid()
q = 1 - p_f[:, :, 4]  # (32, 9): probability that node i exists
                      # (last atom channel taken as the "no atom" class)
A = 1 - p_W[:, :, :, 0]  # (32, 9, 9): probability that edge (i, j) exists
                         # (channel 0 taken as the "no bond" class)
A_0 = torch.eye(n_node).reshape((1, n_node, n_node))
A_i = A
B = A_0.repeat(reg_sig.size(0), 1, 1)  # B starts as the identity for every graph
# Soft transitive closure: A_i approximates a binarized A^i, and B
# accumulates reachability within up to n_node - 1 steps.
for i in range(1, n_node):
    A_i = Sig(100 * (torch.bmm(A_i, A) - 0.5))  # steep sigmoid ~ hard 0.5 threshold
    B += A_i
C = Sig(100 * (B - 0.5))  # C[:, i, j] ~ 1 iff node j is reachable from node i
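One idea I have been toying with for the loop above: since C only encodes reachability within n_node - 1 steps, repeated squaring of the self-loop-augmented adjacency should need only about log2(n_node) batched multiplications instead of n_node - 1. This is only a sketch (M_reach and n_squarings are names I made up) and I have not verified that it matches C numerically:

import math

# Start from "reachable in <= 1 step" (A plus self-loops); each squaring
# then doubles the path length the soft threshold can see.
n_squarings = math.ceil(math.log2(n_node - 1))
M_reach = Sig(100 * (A + A_0 - 0.5))
for _ in range(n_squarings):
    M_reach = Sig(100 * (torch.bmm(M_reach, M_reach) - 0.5))
# M_reach should approximate C (reachability within n_node - 1 steps)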
# reg_g_ij[:, i, j] = q_i q_j (1 - C_ij) + (1 - q_i q_j) C_ij.
# Every entry is overwritten below, so an uninitialized buffer suffices.
reg_g_ij = torch.empty([reg_sig.size(0), n_node, n_node])
for i in range(n_node):
    for j in range(n_node):
        reg_g_ij[:, i, j] = q[:, i] * q[:, j] * (1 - C[:, i, j]) + (1 - q[:, i] * q[:, j]) * C[:, i, j]
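For the double loop, I think plain broadcasting does the same thing in one shot; below is my attempt (q_outer and reg_g_vec are names I made up), followed by the check I would use against the loop version:

# Outer product over the node axis: q_outer[b, i, j] = q[b, i] * q[b, j].
q_outer = q.unsqueeze(2) * q.unsqueeze(1)  # (32, 9, 9)
reg_g_vec = q_outer * (1 - C) + (1 - q_outer) * C

# Sanity check against the loop result:
print(torch.allclose(reg_g_ij, reg_g_vec, atol=1e-6))

The outer product could equally be written as torch.einsum('bi,bj->bij', q, q) if that reads better.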
I believe my implementation is computationally inefficient and would like suggestions on which parts I can change. Specifically, I would like to get rid of the loops and replace them with matrix operations where possible. Any suggestions, working examples, or links to useful torch functions would be appreciated.