def load_network(gpus):
    """Build the generator, the discriminators and the Inception model.

    Every network is initialised with ``weights_init`` and wrapped in
    ``torch.nn.DataParallel``. Checkpoints named by ``cfg.TRAIN.NET_G`` and
    ``cfg.TRAIN.NET_D`` are loaded when those settings are non-empty.

    Args:
        gpus: Device ids passed to ``torch.nn.DataParallel`` as ``device_ids``.

    Returns:
        A tuple ``(netG, netsD, num_Ds, inception_model, count)``. ``netsD``
        is the list of wrapped discriminators and ``num_Ds`` its length.
        ``count`` is 0 when ``cfg.TRAIN.NET_G`` is empty; otherwise it is one
        more than the integer found between the last ``'_'`` and the last
        ``'.'`` of ``cfg.TRAIN.NET_G``.

    Raises:
        ValueError: If the generator checkpoint name carries no integer in
            that position.
    """
    netG = G_NET_hmx()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per branch, up to the five resolutions available.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        # Map tensors to CPU storage first so a checkpoint saved on a GPU
        # that is absent here can still be read; the networks are moved to
        # CUDA further down when cfg.CUDA is set.
        state_dict = torch.load(
            cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # Log the exact file that is read (no underscore before the index).
            path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % path)
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for netD in netsD:
            netD.cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
def load_Dnet(self, gpus):
    """Build the discriminators and load their saved weights.

    Exits the process with status -1 when ``cfg.TRAIN.NET_D`` is empty.
    Otherwise builds one discriminator per branch (at most five), initialises
    each with ``weights_init``, wraps it in ``torch.nn.DataParallel`` and loads
    ``'<cfg.TRAIN.NET_D>netD<i>_70000.pth'`` into it. Also stores the number
    of discriminators on ``self.num_Ds``.

    Args:
        gpus: Device ids passed to ``torch.nn.DataParallel`` as ``device_ids``.

    Returns:
        The list of wrapped discriminators, moved to CUDA when ``cfg.CUDA``
        is set.
    """
    if cfg.TRAIN.NET_D == '':
        print('Error: the path for morels is not found!')
        sys.exit(-1)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())

    self.num_Ds = len(netsD)

    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)

    # Iterate over the networks that were actually built rather than
    # cfg.TREE.BRANCH_NUM, which may exceed the five available discriminators.
    for i, netD in enumerate(netsD):
        path = '%snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i)
        # Log the exact file that is read.
        print('Load %s' % path)
        state_dict = torch.load(
            path, map_location=lambda storage, loc: storage)
        netD.load_state_dict(state_dict)

    if cfg.CUDA:
        for netD in netsD:
            netD.cuda()

    return netsD
def load_network(gpus):
    """Build the generator, the discriminators and the Inception model.

    Every network is initialised with ``weights_init`` and wrapped in
    ``torch.nn.DataParallel``. Checkpoints named by ``cfg.TRAIN.NET_G`` and
    ``cfg.TRAIN.NET_D`` are loaded when those settings are non-empty.

    Args:
        gpus: Device ids passed to ``torch.nn.DataParallel`` as ``device_ids``.

    Returns:
        A tuple ``(netG, netsD, num_Ds, inception_model, count)``. ``netsD``
        is the list of wrapped discriminators and ``num_Ds`` its length.
        ``count`` is 0 when ``cfg.TRAIN.NET_G`` is empty. Otherwise it is one
        more than the integer between the last ``'_'`` and the last ``'.'``
        of ``cfg.TRAIN.NET_G``, or, when that text is not an integer, one
        more than the integer stored in
        ``<cfg.DATA_DIR>/<cfg.LAST_RUN_DIR>/Model/count.txt``.
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per branch, up to the five resolutions available.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        # Map tensors to CPU storage first so a checkpoint saved on a GPU
        # that is absent here can still be read; the networks are moved to
        # CUDA further down when cfg.CUDA is set.
        state_dict = torch.load(
            cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        try:
            count = int(cfg.TRAIN.NET_G[istart:iend])
        except ValueError:
            # The checkpoint name has no integer suffix: fall back to the
            # counter file written under the last run's Model directory.
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # Log the exact file that is read (no underscore before the index).
            path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % path)
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for netD in netsD:
            netD.cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
def load_network(gpus: list, distributed: bool):
    """Build the generator, the discriminators and the Inception model.

    Every network is initialised with ``weights_init`` (the Inception model
    excepted) and wrapped in ``torch.nn.parallel.DistributedDataParallel``
    when ``distributed`` is true, or ``torch.nn.DataParallel`` otherwise.
    Checkpoints named by ``cfg.TRAIN.NET_G`` and ``cfg.TRAIN.NET_D`` are
    loaded when those settings are non-empty.

    Args:
        gpus: Device ids passed to the parallel wrappers as ``device_ids``;
            ``gpus[0]`` is used as ``output_device`` in distributed mode.
        distributed: Selects ``DistributedDataParallel`` over ``DataParallel``.
            NOTE(review): presumably the caller has already initialised the
            process group -- confirm.

    Returns:
        A tuple ``(netG, netsD, num_Ds, inception_model, count)``. ``netsD``
        is the list of wrapped discriminators and ``num_Ds`` its length.
        ``count`` is 0 when ``cfg.TRAIN.NET_G`` is empty; otherwise it is one
        more than the integer found between the last ``'_'`` and the last
        ``'.'`` of ``cfg.TRAIN.NET_G``.

    Raises:
        ValueError: If the generator checkpoint name carries no integer in
            that position.
    """
    netG = G_NET()
    netG.apply(weights_init)
    if distributed:
        netG = netG.cuda()
        netG = torch.nn.parallel.DistributedDataParallel(
            netG, device_ids=gpus, output_device=gpus[0],
            broadcast_buffers=True)
    else:
        if cfg.CUDA:
            netG = netG.cuda()
        netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per branch, up to the five resolutions available.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        if distributed:
            netsD[i] = torch.nn.parallel.DistributedDataParallel(
                netD.cuda(), device_ids=gpus, output_device=gpus[0],
                broadcast_buffers=True)
        else:
            netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
        print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        # Tensors are mapped to CPU storage so the checkpoint loads no matter
        # which device it was saved from.
        state_dict = torch.load(
            cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # Log the exact file that is read (no underscore before the index).
            path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % path)
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()
    if not distributed:
        if cfg.CUDA:
            netG.cuda()
            for netD in netsD:
                netD.cuda()
            inception_model = inception_model.cuda()
        inception_model = torch.nn.DataParallel(
            inception_model, device_ids=gpus)
    else:
        inception_model = torch.nn.parallel.DistributedDataParallel(
            inception_model.cuda(), device_ids=gpus, output_device=gpus[0])
    inception_model.eval()

    # netsD is empty when cfg.TREE.BRANCH_NUM <= 0; report None in that case
    # instead of failing on netsD[0].
    d_device_ids = netsD[0].device_ids if netsD else None
    print("model device, G:{}, D:{}, incep:{}".format(
        netG.device_ids, d_device_ids, inception_model.device_ids))

    return netG, netsD, len(netsD), inception_model, count