def load_model(self):
    """Build the encoder, decoder and transform module and load their weights.

    Reads ``self.opt.layer`` ('r31' or 'r41') to pick the encoder/decoder pair,
    and the checkpoint paths ``self.opt.vgg_dir``, ``self.opt.decoder_dir`` and
    ``self.opt.matrix_dir``. Sets ``self.vgg``, ``self.dec`` and ``self.matrix``,
    each loaded and moved to ``self.device``.

    Raises:
        ValueError: if ``self.opt.layer`` is not 'r31' or 'r41'.
    """
    # MODEL
    layer = self.opt.layer
    if layer == 'r31':
        self.vgg = encoder3()
        self.dec = decoder3()
    elif layer == 'r41':
        self.vgg = encoder4()
        self.dec = decoder4()
    else:
        # Fail loudly here; otherwise self.vgg / self.dec would be missing (or
        # left over from an earlier call) and the error would surface later.
        raise ValueError(
            "unsupported layer {!r}; expected 'r31' or 'r41'".format(layer))
    self.matrix = MulLayer(layer=layer)

    # Map every checkpoint (not only the matrix one) onto the target device so
    # weights saved on a GPU can also be loaded on a CPU-only machine.
    for module, path in (
        (self.vgg, self.opt.vgg_dir),
        (self.dec, self.opt.decoder_dir),
        (self.matrix, self.opt.matrix_dir),
    ):
        module.load_state_dict(torch.load(path, map_location=self.device))
        module.to(self.device)
def __init__(self, root):
    """Build the r41 encoder, decoder and transform module with pretrained weights.

    Args:
        root: Base directory that contains ``python_package/models/``. May be a
            string or a path-like object; a trailing separator is optional.

    All checkpoints are loaded with ``map_location="cpu"``.
    """
    import os  # local import keeps this method self-contained

    super(LinearStyleTransfer, self).__init__()
    self.vgg = encoder4()
    self.dec = decoder4()
    self.matrix = MulLayer("r41")

    # os.path.join gives the same path as plain concatenation when root ends
    # with a separator (or is empty), and a correct one when it does not.
    for module, filename in (
        (self.vgg, "python_package/models/vgg_r41.pth"),
        (self.dec, "python_package/models/dec_r41.pth"),
        (self.matrix, "python_package/models/r41.pth"),
    ):
        weights_path = os.path.join(root, filename)
        module.load_state_dict(torch.load(weights_path, map_location="cpu"))
style_loader_ = torch.utils.data.DataLoader(dataset=style_dataset,
                                            batch_size=opt.batchSize,
                                            shuffle=True,
                                            num_workers=1,
                                            drop_last=True)
style_loader = iter(style_loader_)

################# MODEL #################
# Construction order (vgg5, then matrix/vgg/dec) is kept as-is so that seeded
# random initialisation of the untrained modules stays reproducible.
vgg5 = loss_network()
if opt.layer == 'r31':
    matrix = MulLayer('r31')
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    matrix = MulLayer('r41')
    vgg = encoder4()
    dec = decoder4()
else:
    # Without this, the first use of vgg/dec/matrix below fails with NameError.
    raise ValueError(
        "unsupported layer {!r}; expected 'r31' or 'r41'".format(opt.layer))
vgg.load_state_dict(torch.load(opt.vgg_dir))
vgg5.load_state_dict(torch.load(opt.loss_network_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))

# Freeze the encoder, loss network and transform module.
# NOTE(review): the decoder is deliberately neither loaded from
# opt.decoder_dir nor frozen, so it keeps its default initialisation and stays
# trainable -- presumably it is the module optimised below; confirm.
for frozen_module in (vgg, vgg5, matrix):
    for param in frozen_module.parameters():
        param.requires_grad = False

################# LOSS & OPTIMIZER #################
################# PREPARATIONS #################
args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print_options(args)

os.makedirs(args.outf, exist_ok=True)
# NOTE(review): these keep only the text before the first '.' of the last
# '/'-separated component, e.g. "dir/a.b.jpg" -> "a".
content_name = args.content.split("/")[-1].split(".")[0]
style_name = args.style.split("/")[-1].split(".")[0]
device = torch.device(args.device)

################# MODEL #################
if args.layer == 'r31':
    vgg = encoder3().to(device)
    dec = decoder3().to(device)
elif args.layer == 'r41':
    vgg = encoder4().to(device)
    dec = decoder4().to(device)
else:
    # Without this, the first use of vgg/dec below fails with NameError.
    raise ValueError(
        "unsupported layer {!r}; expected 'r31' or 'r41'".format(args.layer))
matrix = MulLayer(args.layer).to(device)

# map_location lets checkpoints saved on a GPU load when `device` is the CPU
# (or a different GPU); the modules themselves already live on `device`.
vgg.load_state_dict(torch.load(args.vgg_dir, map_location=device))
dec.load_state_dict(torch.load(args.decoder_dir, map_location=device))
matrix.load_state_dict(torch.load(args.matrixPath, map_location=device))

PATCH_SIZE = args.patch_size
PADDING = args.padding

content_tf = test_transform(0, False)
style_tf = test_transform(args.style_size, True)

# Repeat the run to average timings when benchmarking speed.
repeat = 15 if args.test_speed else 1
time_list = []