def load_dataset(self, opt):
    """Create the training and test data loaders for pretraining.

    Each split is built as ``data.DatasetPretrain(opt, load_test=...)`` and
    wrapped with ``data.setup_loader``. The training loader is shuffled; the
    test loader is not.

    Args:
        opt: options object, forwarded unchanged to the dataset and loader.

    Sets:
        self.train_loader: loader over the training split.
        self.test_loader: loader over the test split.
    """
    def _build_loader(message, is_test, shuffle):
        # Announce, construct the dataset split, and wrap it in a loader.
        print(util.magenta(message))
        split = data.DatasetPretrain(opt, load_test=is_test)
        return data.setup_loader(opt, split, shuffle=shuffle)

    self.train_loader = _build_loader("loading training data...", False, True)
    self.test_loader = _build_loader("loading test data...", True, False)
def restore_checkpoint(self, opt):
    """Restore network/optimizer/scheduler state, if a checkpoint is requested.

    The first matching rule wins:
      1. ``opt.from_epoch != 0``: ``util.restore_checkpoint_from_epoch``.
      2. ``opt.load is not None``: ``util.restore_checkpoint`` from ``opt.load``.
      3. ``opt.imagenet_enc`` truthy or ``opt.pretrained_dec`` set: nothing is
         restored here (presumably the pretrained weights are loaded elsewhere
         -- confirm against the caller).
      4. Otherwise a "training from scratch" message is printed.

    Both restore paths use the keys ``["network", "optim", "sched"]``.
    """
    keys = ["network", "optim", "sched"]
    if opt.from_epoch != 0:
        util.restore_checkpoint_from_epoch(opt, self, keys)
        return
    if opt.load is not None:
        util.restore_checkpoint(opt, self, opt.load, keys)
        return
    has_pretrained_parts = opt.imagenet_enc or opt.pretrained_dec is not None
    if not has_pretrained_parts:
        print(util.magenta("training from scratch..."))
def load_pretrained_decoder(self, opt):
    """Load decoder weights from a checkpoint file into ``self.decoder``.

    The file at ``opt.pretrained_dec`` is read with ``torch.load`` (tensors
    mapped onto ``opt.device``). Only entries whose key *starts with* the
    decoder prefix are kept. The prefix is stripped so the remaining keys
    line up with ``self.decoder``'s own state dict.

    Args:
        opt: options object; reads ``opt.pretrained_dec`` (checkpoint path)
            and ``opt.device``.
    """
    print(
        util.magenta("loading pretrained decoder ({})...".format(
            opt.pretrained_dec)))
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    weight_dict = torch.load(opt.pretrained_dec, map_location=opt.device)
    # Accept both "decoder." (torch module separator) and "decoder/"; both are
    # 8 characters long. The anchored test replaces a substring test that also
    # caught keys merely containing "decoder" and then mangled them.
    prefixes = ("decoder.", "decoder/")
    prefix_len = len(prefixes[0])
    decoder_weight_dict = {
        key[prefix_len:]: value
        for key, value in weight_dict.items()
        if key.startswith(prefixes)
    }
    self.decoder.load_state_dict(decoder_weight_dict)
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 (``BasicBlock`` with layer counts [2, 2, 2, 2]).

    Args:
        pretrained: if True, download the torchvision ResNet-18 weights and
            load them block by block into every top-level submodule of the
            model, except ``fc``, which is left untouched.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        print(util.magenta("loading pretrained encoder..."))
        weight_dict = model_zoo.load_url(
            "https://download.pytorch.org/models/resnet18-5c106cde.pth")
        # Group parameters by their exact top-level block name (text before
        # the first "."), storing them under the remainder of the key so each
        # group matches that block's own state dict. A single pass; exact
        # names avoid a bare prefix test matching e.g. "layer10.*" for "layer1".
        block_weights = {}
        for key, value in weight_dict.items():
            block, _, sub_key = key.partition(".")
            block_weights.setdefault(block, {})[sub_key] = value
        for block, state in block_weights.items():
            # The classifier head is not loaded (presumably this ResNet's fc
            # differs from torchvision's -- confirm against ResNet).
            if block == "fc":
                continue
            getattr(model, block).load_state_dict(state)
    return model
import numpy as np
import os, sys, time
import torch
import options, data, util
import model

# ---- banner ----
print(util.yellow("======================================================="))
print(util.yellow("main.py (photometric mesh optimization)"))
print(util.yellow("======================================================="))

# ---- configuration ----
print(util.magenta("setting configurations..."))
opt = options.set()

# ---- sequences to optimize ----
print(util.magenta("reading list of sequences..."))
seq_list = data.load_sequence_list(opt, subset=1)
# NOTE(review): the list loaded above is discarded immediately and replaced by
# a single hard-coded (c, m) pair, so only that one sequence is optimized.
# This looks like a debugging override -- confirm it is intended. The id
# string is also 31 characters long; confirm it is not truncated.
seq_list = [("02958343", "eebbce8b77bdb53c82382fde2cafeb9")]

with torch.cuda.device(opt.gpu):
    # One-time model construction and (optional) checkpoint restore.
    pmo = model.Model(opt)
    pmo.build_network(opt)
    pmo.restore_checkpoint(opt)

    print(util.yellow("======= OPTIMIZATION START ======="))
    for seq_c, seq_m in seq_list:
        # Per-sequence setup, then the optimization itself.
        pmo.load_sequence(opt, seq_c, seq_m)
        pmo.setup_visualizer(opt)
        pmo.setup_variables(opt)
        pmo.setup_optimizer(opt)
        pmo.time_start(opt)
        pmo.optimize(opt)
def build_network(self, opt):
    """Build the AtlasNet model and cache its regular face list.

    The network is constructed with ``eval_enc=True`` and ``eval_dec=True``
    (presumably evaluation-mode encoder/decoder -- confirm in ``atlasnet``).

    Sets:
        self.network: the ``atlasnet.AtlasNet`` instance.
        self.faces: the network's ``faces_regular`` attribute.
    """
    print(util.magenta("building AtlasNet..."))
    net = atlasnet.AtlasNet(opt, eval_enc=True, eval_dec=True)
    self.network = net
    self.faces = net.faces_regular
def load_sequence(self, opt, c, m):
    """Load the sequence identified by ``(c, m)`` and remember its identifiers.

    ``c`` and ``m`` are passed straight to ``data.load_sequence`` (presumably
    a category id and a model id -- confirm against ``data``).

    Sets:
        self.sequence: whatever ``data.load_sequence`` returns.
        self.c, self.m: the identifiers used for the load.
    """
    print(util.magenta("loading sequence..."))
    loaded = data.load_sequence(opt, c, m)
    self.sequence = loaded
    # Identifiers are recorded only after a successful load.
    self.c = c
    self.m = m
def build_network(self, opt):
    """Construct the AtlasNet model from ``opt`` and store it as ``self.network``."""
    print(util.magenta("building AtlasNet..."))
    net = atlasnet.AtlasNet(opt)
    self.network = net
import numpy as np
import os, sys, time
import torch
import options, util
import model_pretrain

# Run evaluation / write a checkpoint every this many epochs.
EVAL_INTERVAL = 10
CKPT_INTERVAL = 50

print(util.yellow("======================================================="))
print(
    util.yellow(
        "main_pretrain.py (pretraining with AtlasNet reimplementation)"))
print(util.yellow("======================================================="))

print(util.magenta("setting configurations..."))
opt = options.set()

with torch.cuda.device(opt.gpu):
    trainer = model_pretrain.Model(opt)
    trainer.load_dataset(opt)
    trainer.build_network(opt)
    # The optimizer is set up before restore_checkpoint, presumably so that
    # optimizer/scheduler state can be restored into it.
    trainer.setup_optimizer(opt)
    trainer.restore_checkpoint(opt)
    trainer.setup_visualizer(opt)

    print(util.yellow("======= TRAINING START ======="))
    trainer.time_start(opt)
    for ep in range(opt.from_epoch, opt.to_epoch):
        trainer.train_epoch(opt, ep)
        if (ep + 1) % EVAL_INTERVAL == 0:
            trainer.evaluate(opt, ep)
        if (ep + 1) % CKPT_INTERVAL == 0:
            trainer.save_checkpoint(opt, ep)
    # The periodic save only fires every CKPT_INTERVAL epochs. When to_epoch is
    # not a multiple of it, the final epochs would never be written to disk,
    # so save the last epoch explicitly. Skipped when no epoch ran, or when
    # the loop already saved it.
    last_ep = opt.to_epoch - 1
    if last_ep >= opt.from_epoch and (last_ep + 1) % CKPT_INTERVAL != 0:
        trainer.save_checkpoint(opt, last_ep)
    print(util.yellow("======= TRAINING DONE ======="))