How does PyTorch's nn.Module register submodules?
Question
When I read the Python source of torch.nn.Module, I found that the attribute self._modules is used in many functions, such as self.modules() and self.children(). However, I didn't find any function that updates it. So where is self._modules updated? More generally, how does PyTorch's nn.Module register submodules?
```python
class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    # ... (other methods omitted)

    def named_modules(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m
```
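Because named_modules() only ever reads self._modules, its traversal can be tried out without PyTorch. The sketch below copies that traversal onto a hypothetical Node class (not part of PyTorch) whose _modules dict is filled by hand, to show the dotted names it produces and how the memo set skips a module that is reachable under two names. In nn.Module itself, those dict entries are written by attribute assignment rather than by hand.

```python
from collections import OrderedDict


class Node:
    """Hypothetical stand-in for nn.Module that keeps only the _modules dict."""

    def __init__(self):
        self._modules = OrderedDict()

    def named_modules(self, memo=None, prefix=''):
        # Same pre-order traversal as the excerpt above; `yield from` replaces
        # the inner for-loop. The memo set ensures each object is yielded once.
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:  # empty slots are skipped
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                yield from module.named_modules(memo, submodule_prefix)


root = Node()
encoder = Node()
shared = Node()
root._modules['encoder'] = encoder  # filled by hand for this sketch only
encoder._modules['proj'] = shared
root._modules['head'] = shared      # same object reachable a second time
root._modules['unused'] = None      # None entries are ignored

print([name for name, _ in root.named_modules()])
# ['', 'encoder', 'encoder.proj']
```

The shared module appears only under the first name reached ('encoder.proj'); the later 'head' path is skipped because the object is already in memo.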
Answer
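Modules and parameters are usually registered simply by setting an attribute on an instance of nn.Module. Specifically, this behavior is implemented by customizing the __setattr__ method, which inspects the type of the assigned value and stores it in self._parameters, self._modules, or self._buffers instead of the instance's ordinary __dict__: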
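A stripped-down, pure-Python imitation of this mechanism, assuming a hypothetical MiniModule class that handles only submodules (no parameters, buffers, or hooks): __setattr__ diverts module-valued assignments into _modules, and a __getattr__ fallback (nn.Module defines one as well, searching _parameters, _buffers, and _modules) makes self.block readable even though it never lands in __dict__. This is a sketch of the idea, not PyTorch's implementation.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical imitation of nn.Module's submodule registration."""

    def __init__(self):
        # Bypass our own __setattr__ when creating the bookkeeping dict.
        object.__setattr__(self, '_modules', OrderedDict())

    def __setattr__(self, name, value):
        modules = self.__dict__.get('_modules')
        if isinstance(value, MiniModule):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before MiniModule.__init__() call")
            self.__dict__.pop(name, None)  # drop a plain attribute of the same name
            modules[name] = value          # register the child
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Python calls this only when normal lookup (incl. __dict__) fails.
        modules = self.__dict__.get('_modules', {})
        if name in modules:
            return modules[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'")

    def children(self):
        return iter(self._modules.values())


class Net(MiniModule):
    def __init__(self):
        super().__init__()         # must run first so _modules exists
        self.block = MiniModule()  # routed into self._modules
        self.depth = 2             # ordinary attribute in self.__dict__


net = Net()
print(list(net._modules))                  # ['block']
print('block' in net.__dict__)             # False
print(net.block is net._modules['block'])  # True
```

This also explains the familiar error when a subclass assigns a submodule before calling super().__init__(): _modules does not exist yet, so the assignment is rejected. Besides attribute assignment, nn.Module.add_module(name, module) writes to self._modules directly, which is handy when the child's name is computed at run time.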
```python
def __setattr__(self, name, value):
    # Delete `name` from each given dict, so a name lives in one place only.
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

    params = self.__dict__.get('_parameters')
    # Case 1: assigning a Parameter registers it in self._parameters.
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules)
        self.register_parameter(name, value)
    # Case 2: an existing parameter name may only be reset to None.
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        # Case 3: assigning a Module registers it as a child in self._modules.
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers)
            modules[name] = value
        # Case 4: an existing child name may only be reset to None.
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            # Case 5: an existing buffer name accepts a Tensor or None.
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            # Case 6: anything else becomes an ordinary instance attribute.
            else:
                object.__setattr__(self, name, value)
```
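The order of the branches has practical consequences: a plain attribute is evicted from __dict__ once a module is assigned under the same name, a registered child can be cleared to None but keeps its slot, and any other value for that name is rejected. The sketch below keeps only that dispatch, on hypothetical Param and Mod classes standing in for torch.nn.Parameter and nn.Module; register_parameter() is reduced to a dict write, the Tensor check on buffers is dropped, and the "before __init__" checks are omitted (the sketch assumes __init__ has already run).

```python
from collections import OrderedDict


class Param:
    """Hypothetical stand-in for torch.nn.Parameter."""


class Mod:
    """Hypothetical stand-in for nn.Module, keeping only __setattr__'s dispatch."""

    def __init__(self):
        for attr in ('_parameters', '_buffers', '_modules'):
            object.__setattr__(self, attr, OrderedDict())

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                d.pop(name, None)

        # Assumes __init__ already created the three dicts.
        params = self.__dict__['_parameters']
        buffers = self.__dict__['_buffers']
        modules = self.__dict__['_modules']
        if isinstance(value, Param):
            remove_from(self.__dict__, buffers, modules)
            params[name] = value  # real code calls register_parameter()
        elif name in params:
            if value is not None:
                raise TypeError(
                    f"cannot assign {type(value).__name__} as parameter {name!r}")
            params[name] = value
        elif isinstance(value, Mod):
            remove_from(self.__dict__, params, buffers)
            modules[name] = value
        elif name in modules:
            if value is not None:
                raise TypeError(
                    f"cannot assign {type(value).__name__} as child module {name!r}")
            modules[name] = value
        elif name in buffers:
            buffers[name] = value  # real code also requires a Tensor or None
        else:
            object.__setattr__(self, name, value)


m = Mod()
m.layer = 'placeholder'  # plain attribute, stored in __dict__
m.layer = Mod()          # now a child: the stale __dict__ entry is removed
print('layer' in m.__dict__, list(m._modules))  # False ['layer']

m.layer = None           # allowed: the slot stays registered with value None
print(m._modules['layer'])  # None

try:
    m.layer = 3          # neither a Mod nor None, so it is rejected
except TypeError as e:
    print(e)             # cannot assign int as child module 'layer'
```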
See https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py to find this method.