def __init__(self, layer='r31', std=1., mean=0.):
    """Style-augmentation module: applies a random pre-computed style to
    inputs via a linear feature transform in VGG feature space.

    Args:
        layer: which VGG feature layer the transform operates on
            ('r31' or 'r41'); also selects the checkpoint files under
            ``models/``.
        std, mean: parameters of the normal distribution stored in
            ``self.dist`` (used elsewhere to perturb style statistics).
    """
    super(StyleAugmentation, self).__init__()
    # Load pre-computed style features and their means.
    # NOTE(review): pickle.load is unsafe on untrusted data — 'features.p'
    # must come from a trusted source.
    with open('features.p', 'rb') as handle:
        self.features, self.means = pickle.load(handle)
    self.size = len(self.features)
    print("number of style available: ", self.size)
    # Bug fix: the architectures were hard-coded to the r31 variants even
    # though `layer` selects the checkpoints below, so layer='r41' would
    # have loaded r41 weights into r31 modules. Pick the matching
    # architecture, consistent with load_model() elsewhere in this project.
    if layer == 'r41':
        self.vgg = encoder4()
        self.dec = decoder4()
    else:
        self.vgg = encoder3()
        self.dec = decoder3()
    self.matrix = MulLayer(layer)
    self.vgg.load_state_dict(torch.load('models/vgg_' + layer + '.pth'))
    self.dec.load_state_dict(torch.load('models/dec_' + layer + '.pth'))
    self.matrix.load_state_dict(torch.load('models/' + layer + '.pth'))
    self.dist = torch.distributions.normal.Normal(torch.tensor([mean]), torch.tensor([std]))
def load_model(self):
    """Build the encoder/decoder/transform for ``self.opt.layer`` and load
    their checkpoints onto ``self.device``.

    Reads: self.opt.layer, self.opt.vgg_dir, self.opt.decoder_dir,
    self.opt.matrix_dir, self.device.
    Sets:  self.vgg, self.dec, self.matrix (all moved to self.device).

    Raises:
        ValueError: if ``self.opt.layer`` is not 'r31' or 'r41'
            (previously this surfaced later as an AttributeError).
    """
    # MODEL: pick the architecture matching the requested VGG layer.
    if self.opt.layer == 'r31':
        self.vgg = encoder3()
        self.dec = decoder3()
    elif self.opt.layer == 'r41':
        self.vgg = encoder4()
        self.dec = decoder4()
    else:
        raise ValueError('unsupported layer: %r' % (self.opt.layer,))
    self.matrix = MulLayer(layer=self.opt.layer)
    # Consistency fix: map_location was passed only for the matrix weights;
    # without it, checkpoints saved on GPU fail to load on a CPU-only host.
    self.vgg.load_state_dict(
        torch.load(self.opt.vgg_dir, map_location=self.device))
    self.dec.load_state_dict(
        torch.load(self.opt.decoder_dir, map_location=self.device))
    self.matrix.load_state_dict(
        torch.load(self.opt.matrix_dir, map_location=self.device))
    self.vgg.to(self.device)
    self.dec.to(self.device)
    self.matrix.to(self.device)
num_workers=1, drop_last=True)
content_loader = iter(content_loader_)
# Style images: same Dataset/DataLoader configuration as the content
# loader whose call this chunk completes above.
style_dataset = Dataset(opt.stylePath, opt.loadSize, opt.fineSize)
style_loader_ = torch.utils.data.DataLoader(dataset=style_dataset,
                                            batch_size=opt.batchSize,
                                            shuffle=True,
                                            num_workers=1,
                                            drop_last=True)
style_loader = iter(style_loader_)

################# MODEL #################
# vgg5: the (deeper) loss network used for perceptual losses;
# vgg/dec/matrix: encoder, decoder and linear transform for the chosen layer.
vgg5 = loss_network()
if (opt.layer == 'r31'):
    matrix = MulLayer('r31')
    vgg = encoder3()
    dec = decoder3()
elif (opt.layer == 'r41'):
    matrix = MulLayer('r41')
    vgg = encoder4()
    dec = decoder4()
vgg.load_state_dict(torch.load(opt.vgg_dir))
# NOTE(review): the decoder checkpoint is deliberately not loaded here —
# and `dec` is the only module whose parameters are NOT frozen below, so
# presumably the decoder is being trained; confirm before uncommenting.
# dec.load_state_dict(torch.load(opt.decoder_dir))
vgg5.load_state_dict(torch.load(opt.loss_network_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))
# Freeze every module except the decoder: only `dec` receives gradients.
for param in vgg.parameters():
    param.requires_grad = False
for param in vgg5.parameters():
    param.requires_grad = False
for param in matrix.parameters():
    param.requires_grad = False
parser.add_argument('--URST', action="store_true", help='use URST framework')
parser.add_argument("--device", type=str, default="cuda", help="device")
parser.add_argument('--resize', type=int, default=0, help='resize')

################# PREPARATIONS #################
args = parser.parse_args()
# NOTE(review): args.cuda only records availability; device selection below
# uses args.device (default "cuda"), so a CUDA-less host must pass
# --device cpu explicitly.
args.cuda = torch.cuda.is_available()
print_options(args)

os.makedirs(args.outf, exist_ok=True)

# Output names derived from the input files: basename without extension.
content_name = args.content.split("/")[-1].split(".")[0]
style_name = args.style.split("/")[-1].split(".")[0]

device = torch.device(args.device)

################# MODEL #################
# Encoder/decoder pair matching the requested VGG layer, plus the linear
# feature-transform layer; all moved to the selected device.
if(args.layer == 'r31'):
    vgg = encoder3().to(device)
    dec = decoder3().to(device)
elif(args.layer == 'r41'):
    vgg = encoder4().to(device)
    dec = decoder4().to(device)
matrix = MulLayer(args.layer).to(device)
vgg.load_state_dict(torch.load(args.vgg_dir))
dec.load_state_dict(torch.load(args.decoder_dir))
matrix.load_state_dict(torch.load(args.matrixPath))

PATCH_SIZE = args.patch_size
PADDING = args.padding

# Content transform uses size 0 (presumably "keep original size"); the
# second test_transform argument looks like a crop/resize flag — confirm
# against test_transform's definition.
content_tf = test_transform(0, False)
style_tf = test_transform(args.style_size, True)
def __init__(self):
    """Assemble the r31 transfer pipeline: two VGG encoders (one for the
    content image, one for the style image), the decoder, and the linear
    feature-transform layer."""
    super(Transfer3, self).__init__()
    # Independent encoder instances for the content and style branches.
    self.vgg_c, self.vgg_s = encoder3(), encoder3()
    self.dec = decoder3()
    self.matrix = MulLayer(layer='r31')