Python之Numpy与Pytorch彼此转换时的坑
感兴趣的小伙伴,下面一起跟随php教程的雯雯来看看吧!
前言 &em
这篇文章主要为大家详细介绍了Python之Numpy与Pytorch彼此转换时的坑,具有一定的参考价值,可以用来参考一下。
感兴趣的小伙伴,下面一起跟随php教程的雯雯来看看吧!
前言
最近使用 Numpy包与Pytorch写神经网络时,经常需要两者彼此转换,故用此笔记记录码代码时踩(菜)过的坑,网上有人说:
Pytorch 又被称为 GPU 版的 Numpy,二者的许多功能都有良好的一一对应。
但在使用时还是得多多注意,一个不留神就陷入到了 一根烟一杯酒,一个Bug找一宿 的地步。
1.1、numpy ——> torch
使用 torch.from_numpy() 转换,需要注意,两者共享内存。例子如下:
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | <code> import torch import numpy as np a = np. array ([1,2,3]) b = torch.from_numpy(a) np.add(a, 1, out=a) print ( '转换后a' , a) print ( '转换后b' , b) # 显示 转换后a [2 3 4] 转换后b tensor([2, 3, 4], dtype=torch.int32) </code> |
解决Numpy与Pytorch彼此转换时的坑
1.2、torch——> numpy
使用 .numpy() 转换,同样,两者共享内存。例子如下:
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | <code> import torch import numpy as np a = torch.zeros((2, 3), dtype=torch.float) c = a.numpy() np.add(c, 1, out=c) print ( 'a:' , a) print ( 'c:' , c) # 结果 a: tensor([[1., 1., 1.], [1., 1., 1.]]) c: [[1. 1. 1.] [1. 1. 1.]] </code> |
解决Numpy与Pytorch彼此转换时的坑
需要注意的是,如果将程序中的 np.add(c, 1, out=c) 改成 c = c + 1 会发现两者貌似不共享内存了,其实不然,原因是后者相当于改变了 c 的存储地址。可以使用 id(c) 发现c的内存位置变了。
补充:pytorch中tensor数据和numpy数据转换中注意的一个问题
在pytorch中,把numpy.array数据转换到张量tensor数据的常用函数是torch.from_numpy(array)或者torch.Tensor(array),第一种函数更常用。
下面通过代码看一下区别:
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | <code> import numpy as np import torch a=np.arange(6,dtype=int).reshape(2,3) b=torch.from_numpy(a) c=torch.Tensor(a) a[0][0]=10 print (a, '\n' ,b, '\n' ,c) [[10 1 2] [ 3 4 5]] tensor([[10, 1, 2], [ 3, 4, 5]], dtype=torch.int32) tensor([[0., 1., 2.], [3., 4., 5.]]) c[0][0]=10 print (a, '\n' ,b, '\n' ,c) [[10 1 2] [ 3 4 5]] tensor([[10, 1, 2], [ 3, 4, 5]], dtype=torch.int32) tensor([[10., 1., 2.], [ 3., 4., 5.]]) print (b.type()) torch.IntTensor print (c.type()) torch.FloatTensor </code> |
解决Numpy与Pytorch彼此转换时的坑
可以看出修改数组a的元素值,张量b的元素值也改变了,但是张量c却不变。修改张量c的元素值,数组a和张量b的元素值都不变。
这说明torch.from_numpy(array)是做数组的浅拷贝,torch.Tensor(array)是做数组的深拷贝。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持php教程。
注:关于Python之Numpy与Pytorch彼此转换时的坑的内容就先介绍到这里,更多相关文章的可以留意