开发者

PyTorch小功能之TensorDataset解读

开发者 https://www.devze.com 2023-02-21 09:16 出处:网络 作者: 菜鸟向前冲fighting
目录PyTorch之TensorDatasetPytorch中TensorDataset的快速使用总结PyTorch之TensorDataset TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。
目录
  • PyTorch之TensorDataset
  • Pytorch中TensorDataset的快速使用
  • 总结

PyTorch之TensorDataset

TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。

该类通过每一个 tensor 的第一个维度进行索引。

因此,该类中的 tensor 第一维度必须相等。

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]python, [1, 2, 3], [4, 5编程客栈, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b) 
# 切片输出
print(train_ids[0:2])
print('=' * 80)
# 循环取数据
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, BATch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0} javascriptx_data:{1}  label: {2}'.format(i, x_data, label))

运行结果:

(tensor([[1, 2, 3],

        [4, 5, 6]]), tensor([44, 55]))

================================================================================

tensor([1, 2, 3]) tensor(44)

tensor([4, 5, 6]) tensor(55)

tensor([7, 8, 9]) tensor(66)

tensor([1, 2, 3]) tensor(44)

tensor([4, 5, 6]js) tensor(55)

tensor([7, 8, 9]) tensor(66)

tensor([1, 2, 3]) tensor(44)

tensor([开发者_JS培训4, 5, 6]) tensor(55)

tensor([7, 8, 9]) tensor(66)

tensor([1, 2, 3]) tensor(44)

tensor([4, 5, 6]) tensor(55)

tensor([7, 8, 9]) tensor(66)

================================================================================

 batch:1 x_data:tensor([[1, 2, 3],

        [1, 2, 3],

        [4, 5, 6],

        [4, 5, 6]])  label: tensor([44, 44, 55, 55])

 batch:2 x_data:tensor([[4, 5, 6],

        [7, 8, 9],

        [7, 8, 9],

        [7, 8, 9]])  label: tensor([55, 66, 66, 66])

 batch:3 x_data:tensor([[1, 2, 3],

        [1, 2, 3],

        [7, 8, 9],

        [4, 5, 6]])  label: tensor([44, 44, 66,python 55])

注意:TensorDataset 中的参数必须是 tensor

Pytorch中TensorDataset的快速使用

Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。

只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

0

精彩评论

暂无评论...
验证码 换一张
取 消

关注公众号