torch.testing¶
- 
torch.testing.assert_close(实际、预期、*、allow_subclasses=True、rtol=无、atol=无、equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=无)[来源]¶
- 断言 并且很接近。 - actual- expected- 如果 和 是跨步的、非量化的、实值和有限的,则它们被认为是接近的,如果 - actual- expected- 非有限值 ( 和 ) 仅在相等时被视为接近值。的 是 如果 为 ,则仅被视为彼此相等。 - -inf- inf- NaN- equal_nan- True- 此外,只有当它们具有相同的 - - device(如果是), - (如果是), - (如果是 ),以及 - 步幅 (如果为 )。 如果 or 是元张量,则仅执行属性检查。- check_device- True- dtype- check_dtype- True- layout- check_layout- True- check_stride- True- actual- expected- 如果 和 是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 布局),则它们的跨步成员为 单独检查。指数,即 COO、CSR 和 BSR、 或 和 分别用于 CSC 和 BSC 布局, 始终检查是否相等,而根据上述定义检查值是否接近。 - actual- expected- indices- crow_indices- col_indices- ccol_indices- row_indices- 如果 和 被量化,则如果它们具有相同的 - actual- expected- qscheme()以及- dequantize()根据 定义。- actual并且可以是- expected- Tensor的 Likes 或任何 tensor-or-scalar-likes- torch.Tensor的 可以构造为- torch.as_tensor().除 Python 标量外,输入类型 必须直接相关。此外,并且可以- actual- expected- Sequence的 或- Mapping,在这种情况下,如果它们的结构匹配并且所有 根据上述定义,它们的元素被认为是接近的。- 注意 - Python 标量是类型关系要求的一个例外,因为它们的 ,即 - type()- int,- float和- complex等效于 tensor-like 的 。因此 可以检查不同类型的 Python 标量,但需要 。- dtype- check_dtype=False- 参数
- actual (Any) (实际输入) – 实际输入。 
- expected (Any) (预期输入) – 预期输入。 
- allow_subclasses (bool) – If (默认) 并且除了 Python 标量之外,直接相关的类型的输入 都是允许的。否则,需要类型相等。 - True
- rtol (Optional[float]) - 相对容差。如果指定,还必须指定。如果省略,则默认 基于 的值是通过下表选择的。 - atol- dtype
- atol (Optional[float]) - 绝对容差。如果指定,还必须指定。如果省略,则默认 基于 的值是通过下表选择的。 - rtol- dtype
- check_device (bool) - 如果 (默认),断言相应的张量位于同一张量上 - True- device.如果禁用此检查,则不同- device的 在进行比较之前移动到 CPU。
- check_dtype (bool) - 如果 (默认) 断言相应的张量具有相同的 。如果此 check 时,具有不同 的 张量将提升为公共 (根据 - True- dtype- dtype- dtype- torch.promote_types()) 进行比较。
- check_layout (bool) - 如果 (默认) 断言相应的张量具有相同的 。如果此 check 时,具有不同 的张量会在 比较。 - True- layout- layout
- check_stride (bool) - 如果和相应的张量是跨步的,则断言它们具有相同的步幅。 - True
- msg (Optional[Union[str, Callable[[str], str]]]) – 在期间发生故障时使用的可选错误消息 比较。也可以作为可调用对象传递,在这种情况下,它将与生成的消息一起调用 应返回新消息。 
 
- 提高
- ValueError – 如果没有 - torch.Tensor可以从 input 构造。
- ValueError – 如果仅指定了 or。 - rtol- atol
- NotImplementedError – 如果张量是元张量。这是一个临时限制,将在 前途。 
- AssertionError – 如果相应的输入不是 Python 标量并且没有直接关系。 
- AssertionError – 如果是 ,但相应的输入不是 Python 标量,并且具有 不同的类型。 - allow_subclasses- False
- AssertionError – 如果输入为 - Sequence的 S 中,但它们的长度不匹配。
- AssertionError – 如果输入为 - Mapping的密钥,但它们的键集不匹配。
- AssertionError – 如果相应的张量不具有相同的 . - shape
- AssertionError – 如果为 ,但相应的张量不具有相同的 。 - check_layout- True- layout
- AssertionError – 如果只有一个相应的张量被量化。 
- AssertionError – 如果相应的张量被量化,但具有不同的 - qscheme()的。
- AssertionError – 如果是 ,但相应的张量不在同一张量上 - check_device- True- device.
- AssertionError – 如果为 ,但相应的张量不具有相同的 。 - check_dtype- True- dtype
- AssertionError – 如果为 ,但相应的跨步张量没有相同的步幅。 - check_stride- True
- AssertionError – 如果根据上述定义,相应张量的值不接近。 
 
 - 下表显示了 default 和 for different 's.如果 不匹配,则使用两个容差的最大值。 - rtol- atol- dtype- dtype- dtype- rtol- atol- float16- 1e-3- 1e-5- bfloat16- 1.6e-2- 1e-5- float32- 1.3e-6- 1e-5- float64- 1e-7- 1e-7- complex32- 1e-3- 1e-5- complex64- 1.3e-6- 1e-5- complex128- 1e-7- 1e-7- quint8- 1.3e-6- 1e-5- quint2x4- 1.3e-6- 1e-5- quint4x2- 1.3e-6- 1e-5- qint8- 1.3e-6- 1e-5- qint32- 1.3e-6- 1e-5- 其他 - 0.0- 0.0- 注意 - assert_close()具有严格的默认设置,具有高度可配置性。鼓励用户 自- partial()it 以适应他们的用例。例如,如果需要进行相等性检查,则可能会 定义一个 默认情况下对 every 使用零公差:- assert_equal- dtype- >>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Absolute difference: 9.000000000000001e-10 Relative difference: 9.0 - 例子 - >>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected) - >>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected) - >>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False) - >>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer 
- 
torch.testing.make_tensor(*shape, dtype, device, low=无, high=无, requires_grad=False, 非连续=False, exclude_zero=False)[来源]¶
- 创建具有给定 、 、 和 的张量,并填充有 统一从 中抽取的值。 - shape- device- dtype- [low, high)- 如果指定了 或 ,并且超出 的 的 可表示 finite 值,则它们分别被钳制为最低或最高可表示的有限值。 如果 ,则下表描述了 和 的默认值 , 它们依赖于 。 - low- high- dtype- None- low- high- dtype- dtype- low- high- 布尔类型 - 0- 2- unsigned 整数类型 - 0- 10- 有符号整型 - -9- 10- 浮动类型 - -9- 9- 复杂类型 - -9- 9- 参数
- shape (Tuple[int, ..]) – 定义输出张量形状的单个整数或整数序列。 
- DTYPE ( - torch.dtype) – 返回的张量的数据类型。
- device (Union[str, torch.device]) – 返回张量的设备。 
- low (Optional[Number]) – 设置给定范围的下限(含)。如果提供了数字,则为 固定到给定 dtype 的最小可表示的有限值。When (默认)、 此值是根据 (请参阅上表) 确定的。违约:。 - None- dtype- None
- high (Optional[Number]) – 设置给定范围的上限 (不包括)。如果提供了数字,则为 固定到给定 dtype 的最大可表示有限值。当 (默认) 此值时 是根据 确定的(见上表)。违约:。 - None- dtype- None
- requires_grad (Optional[bool]) – autograd 是否应记录对返回的张量的作。违约:。 - False
- noncontiguous (Optional[bool]) – 如果为 True,则返回的张量将是非连续的。这个参数是 如果构造的 Tensor 少于两个元素,则忽略。 
- exclude_zero (Optional[bool]) – 如果 then 将零替换为 dtype 的小正值 取决于 .对于 bool 和 integer 类型,0 将替换为 1。用于浮动 Point 类型,它被替换为 dtype 的最小正法线数( 对象的“tiny”值),而对于复杂类型,它被替换为复数 其实部和虚部都是复数可表示的最小正正规数 类型。违约。 - True- dtype- dtype- finfo()- False
 
- 提高
- ValueError – 如果为整型 dtype 传递 - requires_grad=True
- ValueError – 如果 . - low > high
- ValueError – 如果 或 是 。 - low- high- nan
- TypeError – 如果此函数不支持。 - dtype
 
 - 例子 - >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0')