关于 pytorch TVloss 代码实现的一些疑惑
时间: 2020-08-17来源:V2EX
前景提要
网上看到普遍的答案是这个 class TVLoss(nn.Module): def __init__(self,TVLoss_weight=1): super(TVLoss,self).__init__() self.TVLoss_weight = TVLoss_weight def forward(self,x): batch_size = x.size()[0] h_x = x.size()[2] w_x = x.size()[3] count_h = self._tensor_size(x[:,:,1:,:]) count_w = self._tensor_size(x[:,:,:,1:]) h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size def _tensor_size(self,t): return t.size()[1]*t.size()[2]*t.size()[3]
这里给出的说的是β=2,且不支持变更. 所以按照这里给出的公式 https://blog.csdn.net/yexiaogu1104/article/details/88395475 β/2, 当β=2 那就是 1 也就是不进行任何操作. 所以最后 return 这里为什么会返回一个 self.TVLoss_weight 2, 为啥要 2 呢..

科技资讯:

科技学院:

科技百科:

科技书籍:

网站大全:

软件大全:

热门排行