跑程序时总是会遇到显存不足的问题
查阅资料之后在代码开头加了这两行
torch.cuda.set_per_process_memory_fraction(0.5, 0)torch.cuda.empty_cache()
比原先能多跑一会儿,但很快又说显存不足了,于是我试了一下在每次训练迭代前清一下缓存,即
def train_iter(self):# ...for epoch in range(config.n_epochs): # 每次迭代开始前torch.cuda.empty_cache() # 新增代码#...
问题解决