🌑

Jenqyang

PyTorch 网络可视化方式

方式一:print()

class Net(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages, then three fully connected layers.

    fc1 expects 16 * 5 * 5 = 400 features, so the two (5x5 conv, 2x2 pool)
    stages must shrink the input to a 5x5 map -- e.g. a 3x32x32 image:
    32 -conv-> 28 -pool-> 14 -conv-> 10 -pool-> 5.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return 10 raw (unnormalized) scores per sample in the batch."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        # Keep the batch dimension and merge channels/height/width.
        x = x.view(x.shape[0], -1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


# Printing a Module lists each registered submodule in registration order.
net = Net()
print(net)
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

方式二:torchsummary 的 summary()

from torchsummary import summary
testnet = MLPregression()
# input_size is the shape of ONE sample, without the batch dimension:
# (1, 8) means each sample is a 1x8 tensor (8 features). torchsummary
# prepends the batch dimension itself, shown as -1 in the "Output Shape"
# column (e.g. [-1, 1, 100]).
# device="cpu" matches where testnet lives; summary() defaults to "cuda",
# which feeds CUDA inputs to this CPU model whenever a GPU is available.
summary(testnet, input_size=(1, 8), device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 1, 100]             900
            Linear-2               [-1, 1, 100]          10,100
            Linear-3                [-1, 1, 50]           5,050
            Linear-4                 [-1, 1, 1]              51
================================================================
Total params: 16,101
Trainable params: 16,101
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.06
Estimated Total Size (MB): 0.06
----------------------------------------------------------------

参考:https://blog.csdn.net/weixin_44979150/article/details/122778521

方式三:torchviz 的 make_dot()

# Draw the network's autograd graph.
from torchviz import make_dot
testnet = MLPregression()
# Random (1, 8) input with requires_grad=True, so autograd records the ops
# applied to it and make_dot can show it as a labelled node.
x = torch.randn(1, 8, requires_grad=True)
y = testnet(x)
# Node labels: every named model parameter, plus the input under the name "x".
myMLP_vis = make_dot(y, params={**dict(testnet.named_parameters()), 'x': x})
# make_dot returns a graphviz graph object; a bare expression like this only
# renders when it is the last line of a Jupyter cell.
myMLP_vis

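方式四:TensorBoard(用 torch.utils.tensorboard.SummaryWriter 的 add_graph(model, input_to_model) 将计算图写入日志,再在 TensorBoard 的 Graphs 页查看)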

方式五:第三方工具

  1. https://github.com/HarisIqbal88/PlotNeuralNet

— Jan 25, 2023

Made with ❤ and Hexo.js at Earth.