【Pytorch】nn.Linear,nn.Conv

网友投稿 253 2022-11-05


【Pytorch】nn.Linear,nn.Conv

nn.Linear

nn.Conv1d

当​​nn.Conv1d​​​的​​kernel_size=1​​​时,效果与​​nn.Linear​​​相同,不过输入数据格式不同:​

import torchdef count_parameters(model): """Count the number of parameters in a model.""" return sum([p.numel() for p in model.parameters()])conv = torch.nn.Conv1d(3, 32, kernel_size=1)print(count_parameters(conv))# 128linear = torch.nn.Linear(3, 32)print(count_parameters(linear))# 128print(conv.weight.shape)# torch.Size([32, 3, 1])print(linear.weight.shape)# torch.Size([32, 3])# use same initializationlinear.weight = torch.nn.Parameter(conv.weight.squeeze(2))linear.bias = torch.nn.Parameter(conv.bias)tensor = torch.randn(128, 256, 3) # [batch, feature_num,feature_size]permuted_tensor = tensor.permute(0, 2, 1).clone().contiguous() # [batch, feature_size,feature_num]out_linear = linear(tensor)print(out_linear.mean())# tensor(0.0344, grad_fn=)print(out_linear.shape)# torch.Size([128, 256, 32])out_conv = conv(permuted_tensor)print(out_conv.mean())# tensor(0.0344, grad_fn=)print(out_conv.shape)# torch.Size([128, 32, 256])

nn.Conv2d

nn.Conv3d


版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Spring Boot 2.x基础教程之配置元数据的应用
下一篇:【Pytorch】nn.ReLU(inplace=True)
相关文章

 发表评论

暂时没有评论,来抢沙发吧~