TungDuong commited on
Commit
e1eab30
·
verified ·
1 Parent(s): d47b48c

base model

Browse files
Files changed (1) hide show
  1. vgg19.py +43 -0
vgg19.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torchvision import models
5
+
6
+ class VGG19(nn.Module):
7
+ def __init__(self, required_grad=False):
8
+ super(VGG19, self).__init__()
9
+ self.required_grad = required_grad
10
+
11
+ self.vgg19 = models.vgg19(weights='IMAGENET1K_V1', progress=True)
12
+ self.feature_maps = list(self.vgg19.children())[0]
13
+ self.conv_layers = nn.Sequential(*self.feature_maps)
14
+
15
+ for layers, params in self.vgg19.named_parameters():
16
+ if not self.required_grad:
17
+ params.requires_grad = False
18
+
19
+ def forward(self, x, mode='style'):
20
+ feature_maps = []
21
+
22
+ if mode == 'style':
23
+ layers = [0, 5, 10, 19, 28]
24
+ for i in range(len(self.feature_maps)):
25
+ x = self.feature_maps[i](x)
26
+ if i in layers:
27
+ feature_maps.append(x)
28
+ return feature_maps
29
+
30
+ elif mode == 'content':
31
+ layer = 21
32
+ for i in range(len(self.feature_maps)):
33
+ x = self.feature_maps[i](x)
34
+ if i == layer:
35
+ return x
36
+
37
+ def get_feature_maps(self, image):
38
+ feature_maps = []
39
+ for i in range(len(self.conv_layers)):
40
+ image = self.conv_layers[i](image)
41
+ if type(self.conv_layers[i]) == nn.Conv2d:
42
+ feature_maps.append(image)
43
+ return feature_maps