1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 神经网络加载数据 自建数据集 官方数据集 pytorch 显示数据集

神经网络加载数据 自建数据集 官方数据集 pytorch 显示数据集

时间:2022-03-19 16:03:20

相关推荐

神经网络加载数据 自建数据集 官方数据集 pytorch 显示数据集

1.官方的数据集 MNIST

使用torchvision.datasets 里面有很多数据集供选择

import torchimport torchvisionfrom torchvision import transforms, modelsbatch_size = 32 transform = pose([transforms.ToTensor(),transforms.Normalize(mean=(0.5),std=(0.5)),])train_data = torchvision.datasets.MNIST('./mn',train=True,download=True,transform=transform)data_loader_train = torch.utils.data.DataLoader(dataset=train_data,batch_size= batch_size,shuffle=True)test_data = torchvision.datasets.MNIST('./mn',train=False,download=True,transform=transform)data_loader_test = torch.utils.data.DataLoader(dataset=test_data,batch_size= batch_size,shuffle=True)

next(iter(data_loader_train)) # 用于查看数据

2.自建的数据集

读取单个数据文件

device=('cuda' if torch.cuda.is_available() else 'cpu')def load_img(image_path,transform=None,max_size=None,shape=None):image=Image.open(image_path)if max_size:scale=max_size/max(image.size)size=np.array(image.size)*scaleimage=image.resize(size.astype(int),Image.ANTIALIAS)if shape:image=image.resize(shape)if transform:image=transform(image).unsqueeze(0)return image.to(device)transform = pose([transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]),])content = load_img('image/content.jpg',transform,max_size=400)style = load_img('image/style.jpg',transform,max_size=400)

多张图片的情况 ImageFloder

这个时候需要把不同label 的数据放到不同的文件夹,ImageFolder 会自动加上标签,

from torchvison import datasetsdata_dir = './data'all_imgs=datasets.ImageFolder(os.path.join(data_dir,"train"),pose([transforms.RandomResizedCrop(input_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]))loader = torch.utils.data.DataLoader(all_imgs,batch_size=batch_size,shuffle=True)

img=next(iter(loader))[0]unloader=transforms.ToPILImage()def imshow(tensor,title=None):image=tensor.cpu().clone()image=image.squeeze(0)image=unloader(image)plt.imshow(image)if title is not None:plt.title(title)plt.pause(0.001)plt.figure()imshow(img[31],title='image')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。