1 2 3 4 5 6 7 8 9
class Normalize(torch.nn.Module):
    """Scale pixel values by ``1 / max_value`` and apply per-channel mean/std normalization.

    Computes ``(images / max_value - mean) / std``. The defaults are the
    ImageNet channel statistics with a 255 scale, i.e. the behaviour of the
    original hard-coded module.

    The statistics are stored as buffers named ``mean``, ``std`` and ``norm``.
    These names are the module's ``state_dict`` keys, so they are kept stable.
    ``mean`` and ``std`` have shape ``(1, C, 1, 1)`` and ``norm`` has shape
    ``(1, 1, 1, 1)``. Inputs therefore broadcast with channels on dim 1
    (presumably ``(N, C, H, W)`` images — confirm against callers).
    """

    def __init__(
        self,
        *,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        max_value: float = 255.0,
        dtype: torch.dtype = torch.float64,
    ) -> None:
        """Build the normalization buffers.

        Args:
            mean: Per-channel means applied after scaling (sequence or 1-D tensor).
            std: Per-channel standard deviations; same length as ``mean``.
            max_value: Divisor mapping raw pixel values to the unit range.
            dtype: Dtype of the stored buffers. Defaults to float64 as in the
                original module.

        Raises:
            ValueError: If ``mean`` is empty, ``mean``/``std`` lengths differ,
                ``std`` contains a zero, or ``max_value`` is zero.
        """
        super().__init__()
        mean_t = self._channel_buffer(mean, dtype)
        std_t = self._channel_buffer(std, dtype)
        # Fail at construction time rather than with a broadcasting error
        # (or silent inf/nan) on the first forward pass.
        if mean_t.numel() == 0 or mean_t.numel() != std_t.numel():
            raise ValueError("mean and std must be non-empty and have the same length")
        if bool((std_t == 0).any()):
            raise ValueError("std must not contain zeros")
        if max_value == 0:
            raise ValueError("max_value must be non-zero")
        self.register_buffer('mean', mean_t)
        self.register_buffer('std', std_t)
        self.register_buffer('norm', torch.tensor([max_value], dtype=dtype).view(1, 1, 1, 1))

    @staticmethod
    def _channel_buffer(values, dtype: torch.dtype) -> torch.Tensor:
        """Return *values* as a fresh ``(1, C, 1, 1)`` tensor of *dtype*."""
        # detach().clone() ensures the buffer never aliases a caller-owned tensor.
        return torch.as_tensor(values, dtype=dtype).detach().clone().reshape(1, -1, 1, 1)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Return ``(images / norm - mean) / std``.

        With the default float64 buffers, PyTorch type promotion makes the
        result float64 even for float32 or integer (e.g. uint8) inputs. Pass
        ``dtype=torch.float32`` (or call ``.float()`` on the module) if the
        downstream model expects float32.
        """
        return (images / self.norm - self.mean) / self.std
|