为 pytorch 自定义 C 扩展

译者:@飞龙

作者: Soumith Chintala

第一步. 准备你的 C 代码

首先, 你需要编写你的 C 函数.

下面你可以找到模块的正向和反向函数的示例实现, 它将两个输入相加.

在你的 .c 文件中, 你可以使用 #include <TH/TH.h> 直接包含 TH, 以及使用 #include <THC/THC.h> 包含 THC.

ffi (外来函数接口) 工具会确保编译器可以在构建过程中找到它们.

  1. /* src/my_lib.c */
  2. #include <TH/TH.h>
  3. int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
  4. THFloatTensor *output)
  5. {
  6. if (!THFloatTensor_isSameSizeAs(input1, input2))
  7. return 0;
  8. THFloatTensor_resizeAs(output, input1);
  9. THFloatTensor_cadd(output, input1, 1.0, input2);
  10. return 1;
  11. }
  12. int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
  13. {
  14. THFloatTensor_resizeAs(grad_input, grad_output);
  15. THFloatTensor_fill(grad_input, 1);
  16. return 1;
  17. }

代码没有任何限制, 除了你必须准备单个头文件, 它会列出所有你想要从 Python 调用的函数.

它会由 ffi 用于生成合适的包装.

  1. /* src/my_lib.h */
  2. int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
  3. int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

现在, 你需要一个超短的文件, 它会构建你的自定义扩展:

  1. # build.py
  2. from torch.utils.ffi import create_extension
  3. ffi = create_extension(
  4. name='_ext.my_lib',
  5. headers='src/my_lib.h',
  6. sources=['src/my_lib.c'],
  7. with_cuda=False
  8. )
  9. ffi.build()

第二步: 在你的 Python 代码中包含它

你运行它之后, pytorch 会创建一个 _ext 目录, 并把 my_lib 放到里面.

包名称可以在最终模块名称之前, 包含任意数量的包 (包括没有). 如果构建成功, 你可以导入你的扩展, 就像普通的 Python 文件.

  1. # functions/add.py
  2. import torch
  3. from torch.autograd import Function
  4. from _ext import my_lib
  5. class MyAddFunction(Function):
  6. def forward(self, input1, input2):
  7. output = torch.FloatTensor()
  8. my_lib.my_lib_add_forward(input1, input2, output)
  9. return output
  10. def backward(self, grad_output):
  11. grad_input = torch.FloatTensor()
  12. my_lib.my_lib_add_backward(grad_output, grad_input)
  13. return grad_input
  1. # modules/add.py
  2. from torch.nn import Module
  3. from functions.add import MyAddFunction
  4. class MyAddModule(Module):
  5. def forward(self, input1, input2):
  6. return MyAddFunction()(input1, input2)
  1. # main.py
  2. import torch
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. from modules.add import MyAddModule
  6. class MyNetwork(nn.Module):
  7. def __init__(self):
  8. super(MyNetwork, self).__init__()
  9. self.add = MyAddModule()
  10. def forward(self, input1, input2):
  11. return self.add(input1, input2)
  12. model = MyNetwork()
  13. input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
  14. print(model(input1, input2))
  15. print(input1 + input2)