1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > Only tensors or tuples of tensors can be output from traced functions错误解决

Only tensors or tuples of tensors can be output from traced functions错误解决

时间:2022-05-07 10:16:22

相关推荐

Only tensors or tuples of tensors can be output from traced functions错误解决

(TorchScript应用) PyTorch模型转换为Torch脚本的代码出错。

出现原因:想将pytorch训练的.pth文件转成C++能处理的.pt文件。用的TorchScript的方法。具体代码如下:

import argparse

import cv2

import torchvision

import numpy as np

import torch

import torch.nn.functional as F

from torchvision import transforms

from PIL import Image

from tqdm import tqdm

from unet import NestedUNet

from unet import UNet

from utils.dataset import BasicDataset

from config import UNetConfig

cfg = UNetConfig()

device = torch.device('cpu')

model = eval(cfg.model)(cfg)

path = 'data/checkpoints/epoch_9.pth'#自己模型训练的结果

model.load_state_dict(torch.load(path, map_location=device))

model.to(device=device)

print(model)img_path = 'data/00000.jpg'#自己模型准备用的数据

img = cv2.imread(img_path)

imgXX = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))

img1 = torch.from_numpy(BasicDataset.preprocess(imgXX, cfg.scale))

img1 = img1.unsqueeze(0)

img1 = img1.to(device=device, dtype=torch.float32)

traced_script_module = torch.jit.trace(model,img1)#报错的代码位置

traced_script_module.save("torch_script_eval.pt")

报错的内容为:Only tensors or tuples of tensors can be output from traced functions

返回的是字典,是不支持的。想办法把返回的字典变成 tensors,就可以了。

怀疑model这个模型的返回值非法了。找到model模型定义的返回值位置

def forward(self, input):

x0_0 = self.conv0_0(input)

x1_0 = self.conv1_0(self.pool(x0_0))

x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

x2_0 = self.conv2_0(self.pool(x1_0))

x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))

x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

x3_0 = self.conv3_0(self.pool(x2_0))

x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))

x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))

x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

x4_0 = self.conv4_0(self.pool(x3_0))

x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))

x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))

x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))

x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

if self.deepsupervision:

output1 = self.final1(x0_1)

output2 = self.final2(x0_2)

output3 = self.final3(x0_3)

output4 = self.final4(x0_4)

return [output1, output2, output3, output4)] #按程序设定 走的是这个分支。 显然[ ]不符合。

else:

output = self.final(x0_4)

return output

将其改为return (output1, output2, output3, output4))#报错解决了。

总结:raced_script_module = torch.jit.trace(model,img1)这个位置报这类错,就怀疑返回值问题,然后一步步查。

上面这个改可能不符合代码原意了。再看看

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。