torch.sparse

警告 该API目前是实验性的,可能在不久的将来会发生变化。

torch支持COO(rdinate)格式的稀疏张量,可以有效地存储和处理大多数元素为零的张量。

稀疏张量被表示为一对致密张量:值的张量和2D张量的索引。可以通过提供这两个张量来构造稀疏张量,以及稀疏张量的大小(不能从这些张量推断出!)假设我们要在位置(0,2)处定义具有条目3的稀疏张量, ,位置(1,0)的条目4,位置(1,2)的条目5。我们会写:

  1. >>> i = torch.LongTensor([[0, 1, 1],
  2. [2, 0, 2]])
  3. >>> v = torch.FloatTensor([3, 4, 5])
  4. >>> torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()
  5. 0 0 3
  6. 4 0 5
  7. [torch.FloatTensor of size 2x3]

请注意,LongTensor的输入不是索引元组的列表。如果要以这种方式编写索引,则在将它们传递给稀疏构造函数之前,应该进行转置:

  1. >>> i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
  2. >>> v = torch.FloatTensor([3, 4, 5 ])
  3. >>> torch.sparse.FloatTensor(i.t(), v, torch.Size([2,3])).to_dense()
  4. 0 0 3
  5. 4 0 5
  6. [torch.FloatTensor of size 2x3]

您还可以构建混合稀疏张量,其中只有第一个n维是稀疏的,其余的维度是密集的。

  1. >>> i = torch.LongTensor([[2, 4]])
  2. >>> v = torch.FloatTensor([[1, 3], [5, 7]])
  3. >>> torch.sparse.FloatTensor(i, v).to_dense()
  4. 0 0
  5. 0 0
  6. 1 3
  7. 0 0
  8. 5 7
  9. [torch.FloatTensor of size 5x2]

可以通过指定一个空的稀疏张量来构建一个空的稀疏张量:

  1. print torch.sparse.FloatTensor(2, 3)
  2. # FloatTensor of size 2x3 with indices:
  3. # [torch.LongTensor with no dimension]
  4. # and values:
  5. # [torch.FloatTensor with no dimension]

注意:

我们的稀疏张量格式允许未被缩小的稀疏张量,其中索引中可能有重复的坐标; 在这种情况下,解释是该索引处的值是所有重复值条目的总和。非协调张量允许我们更有效地实施某些运营商。

在大多数情况下,您不必关心稀疏张量是否合并,因为大多数操作将在合并或未被缩小的稀疏张量的情况下工作相同。但是,您可能需要关心两种情况。

首先,如果您反复执行可以产生重复条目(例如torch.sparse.FloatTensor.add())的操作,则应偶尔将您的稀疏张量合并,以防止它们变得太大。

其次,一些运营商将取决于它们是否被合并或不产生不同的值(例如, torch.sparse.FloatTensor._values()和 torch.sparse.FloatTensor._indices(),以及 torch.Tensor._sparse_mask())。这些运算符前面加上一个下划线,表示它们显示内部实现细节,应谨慎使用,因为与聚结的稀疏张量一起工作的代码可能无法与未被缩放的稀疏张量一起使用; 一般来说,在与这些运营商合作之前明确地合并是最安全的。

例如,假设我们想通过直接操作来实现一个操作符torch.sparse.FloatTensor._values()。随着乘法分布的增加,标量的乘法可以以明显的方式实现; 然而,平方根不能直接实现,因为(如果你被赋予了未被缩放的张量,那将是什么)。sqrt(a + b) != sqrt(a) + sqrt(b)

class torch.sparse.FloatTensor

  • add()
  • add_()
  • clone()
  • dim()
  • div()
  • div_()
  • get_device()
  • hspmm()
  • mm()
  • mul()
  • mul_()
  • resizeAs_()
  • size()
  • spadd()
  • spmm()
  • sspaddmm()
  • sspmm()
  • sub()
  • sub_()
  • t_()
  • toDense()
  • transpose()
  • transpose_()
  • zero_()
  • coalesce()
  • is_coalesced()
  • _indices()
  • _values()
  • _nnz()

译者署名

用户名 头像 职能 签名
Song torch.sparse - 图1 翻译 人生总要追求点什么