def load_network(gpus): """hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmx netG = G_NET() netG.apply(weights_init) netG = torch.nn.DataParallel(netG, device_ids=gpus) print(netG) """ #hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxstart netG = G_NET_hmx() netG.apply(weights_init) netG = torch.nn.DataParallel(netG, device_ids=gpus) print(netG) #hmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxhmxend netsD = [] if cfg.TREE.BRANCH_NUM > 0: netsD.append(D_NET64()) if cfg.TREE.BRANCH_NUM > 1: netsD.append(D_NET128()) if cfg.TREE.BRANCH_NUM > 2: netsD.append(D_NET256()) if cfg.TREE.BRANCH_NUM > 3: netsD.append(D_NET512()) if cfg.TREE.BRANCH_NUM > 4: netsD.append(D_NET1024()) # TODO: if cfg.TREE.BRANCH_NUM > 5: for i in range(len(netsD)): netsD[i].apply(weights_init) netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus) # print(netsD[i]) print('# of netsD', len(netsD)) count = 0 if cfg.TRAIN.NET_G != '': state_dict = torch.load(cfg.TRAIN.NET_G) netG.load_state_dict(state_dict) print('Load ', cfg.TRAIN.NET_G) istart = cfg.TRAIN.NET_G.rfind('_') + 1 iend = cfg.TRAIN.NET_G.rfind('.') count = cfg.TRAIN.NET_G[istart:iend] count = int(count) + 1 if cfg.TRAIN.NET_D != '': for i in range(len(netsD)): print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i)) state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i)) netsD[i].load_state_dict(state_dict) inception_model = INCEPTION_V3() if cfg.CUDA: netG.cuda() for i in range(len(netsD)): netsD[i].cuda() inception_model = inception_model.cuda() inception_model.eval() return netG, netsD, len(netsD), inception_model, count
def load_Dnet(self, gpus):
    if cfg.TRAIN.NET_D == '':
        print('Error: the path for the models is not found!')
        sys.exit(-1)
    else:
        netsD = []
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        if cfg.TREE.BRANCH_NUM > 3:
            netsD.append(D_NET512())
        if cfg.TREE.BRANCH_NUM > 4:
            netsD.append(D_NET1024())
        self.num_Ds = len(netsD)

        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)

        for i in range(cfg.TREE.BRANCH_NUM):
            print('Load %snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load(
                '%snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i),
                map_location=lambda storage, loc: storage)
            netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            for i in range(len(netsD)):
                netsD[i].cuda()
        return netsD
def load_network(gpus):
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        try:
            # Parse the iteration count from a name like netG_2000.pth.
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            count = int(cfg.TRAIN.NET_G[istart:iend])
        except ValueError:
            # Fall back to the count recorded by the previous run.
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())
        count += 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
def load_network(gpus):
    netEn_img = MLP_ENCODER_IMG()
    netEn_img.apply(weights_init)
    netEn_img = torch.nn.DataParallel(netEn_img, device_ids=gpus)
    print(netEn_img)

    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    if cfg.TRAIN.NET_MLP_IMG != '':
        state_dict = torch.load(cfg.TRAIN.NET_MLP_IMG)
        netEn_img.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_MLP_IMG)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        netEn_img = netEn_img.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, netEn_img, inception_model, len(netsD), count
def load_network(gpus):
    netG_64 = G_NET_64()
    netG_128 = G_NET_128()
    netG_256 = G_NET_256()
    netG_64.apply(weights_init)
    netG_128.apply(weights_init)
    netG_256.apply(weights_init)
    netG_64 = torch.nn.DataParallel(netG_64, device_ids=gpus)
    netG_128 = torch.nn.DataParallel(netG_128, device_ids=gpus)
    netG_256 = torch.nn.DataParallel(netG_256, device_ids=gpus)
    print(netG_256)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)  # multi-GPU setting
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G_64 != '':
        state_dict = torch.load(cfg.TRAIN.NET_G_64)   # load the 64x64 G checkpoint
        netG_64.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G_64)
    if cfg.TRAIN.NET_G_128 != '':
        state_dict = torch.load(cfg.TRAIN.NET_G_128)  # load the 128x128 G checkpoint
        netG_128.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G_128)
    # istart = cfg.TRAIN.NET_G.rfind('_') + 1  # last occurrence of '_' (searched from the right; -1 if absent)
    # iend = cfg.TRAIN.NET_G.rfind('.')
    # count = cfg.TRAIN.NET_G[istart:iend]     # e.g. netG_2000.pth -> 2000
    # count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG_64.cuda()
        netG_128.cuda()
        netG_256.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG_64, netsD, len(
        netsD), inception_model, count, netG_128, netG_256
def load_network(gpus):
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    # 128 x 128
    if cfg.TREE.BRANCH_NUM > 1:
        # 3 discriminators for the background, parent and child stage
        for i in range(3):
            netsD.append(D_NET128(i))
    # 256 x 256
    if cfg.TREE.BRANCH_NUM > 2:
        # 3 discriminators for the background, parent and child stage
        for i in range(3):
            netsD.append(D_NET256(i))
    # for i in range(3):  # 3 discriminators for background, parent and child stage
    #     netsD.append(D_NET128(i))

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        print(netsD[i])

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s_%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return netG, netsD, len(netsD), count
def build_models(self):
    print('Building models...')
    print('N_words: ', self.n_words)

    #####################
    ##  TEXT ENCODERS  ##
    #####################
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    print('Built image encoder: ', image_encoder)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Built text encoder: ', text_encoder)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    ######################
    ##  CAPTION MODELS  ##
    ######################
    # cnn_encoder and rnn_encoder
    if cfg.CAP.USE_ORIGINAL:
        caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
        caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                  hidden_size=cfg.CAP.HIDDEN_SIZE,
                                  vocab_size=self.n_words,
                                  num_layers=cfg.CAP.NUM_LAYERS)
    else:
        caption_cnn = Encoder()
        caption_rnn = Decoder(idx2word=self.ixtoword)

    caption_cnn_checkpoint = torch.load(
        cfg.CAP.CAPTION_CNN_PATH, map_location=lambda storage, loc: storage)
    caption_rnn_checkpoint = torch.load(
        cfg.CAP.CAPTION_RNN_PATH, map_location=lambda storage, loc: storage)
    caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
    caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

    for p in caption_cnn.parameters():
        p.requires_grad = False
    print('Load caption model from: ', cfg.CAP.CAPTION_CNN_PATH)
    caption_cnn.eval()
    for p in caption_rnn.parameters():
        p.requires_grad = False
    print('Load caption model from: ', cfg.CAP.CAPTION_RNN_PATH)

    #################################
    ##  GENERATOR & DISCRIMINATOR  ##
    #################################
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    text_encoder = text_encoder.to(cfg.DEVICE)
    image_encoder = image_encoder.to(cfg.DEVICE)
    caption_cnn = caption_cnn.to(cfg.DEVICE)
    caption_rnn = caption_rnn.to(cfg.DEVICE)
    netG.to(cfg.DEVICE)
    for i in range(len(netsD)):
        netsD[i].to(cfg.DEVICE)

    return [
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
        epoch
    ]
def build_models(self):
    # text encoders
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    # self.n_words = 156
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # Caption models - cnn_encoder and rnn_decoder
    caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
    caption_cnn.load_state_dict(torch.load(
        cfg.CAP.caption_cnn_path,
        map_location=lambda storage, loc: storage))
    for p in caption_cnn.parameters():
        p.requires_grad = False
    print('Load caption model from:', cfg.CAP.caption_cnn_path)
    caption_cnn.eval()

    # self.n_words = 9
    caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2,
                              self.n_words, cfg.CAP.num_layers)
    caption_rnn.load_state_dict(torch.load(
        cfg.CAP.caption_rnn_path,
        map_location=lambda storage, loc: storage))
    for p in caption_rnn.parameters():
        p.requires_grad = False
    print('Load caption model from:', cfg.CAP.caption_rnn_path)

    # Generator and Discriminator:
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        # print(epoch)
        # print(state_dict.keys())
        # epoch = state_dict['epoch']
        epoch = int(epoch) + 1
        # epoch = 187
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        caption_cnn = caption_cnn.cuda()
        caption_rnn = caption_rnn.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, caption_cnn, caption_rnn,
            netG, netsD, epoch]
def build_models(self):
    # ################### Text and Image encoders ################### #
    # if cfg.TRAIN.NET_E == '':
    #     print('Error: no pretrained text-image encoders')
    #     return
    VGG = VGGNet()
    for p in VGG.parameters():
        p.requires_grad = False
    print("Load the VGG model")
    VGG.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)

    # when NET_E == '', train the image_encoder and text_encoder jointly
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(
            cfg.TRAIN.NET_E,
            map_location=lambda storage, loc: storage).state_dict()
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(
            img_encoder_path,
            map_location=lambda storage, loc: storage).state_dict()
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

    # ################### Generator and Discriminators ############## #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        if cfg.TRAIN.W_GAN:
            netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        netG.apply(weights_init)
        if cfg.TRAIN.W_GAN:
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
    print('# of netsD', len(netsD))

    #
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        VGG = VGG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch, VGG]
def build_models(self):
    # ################### Text and Image encoders ################### #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return
    """
    image_encoder = CNN_dummy()
    image_encoder.cuda()
    image_encoder.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()
    """
    VGG = VGG16()
    VGG.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ################### Generator and Discriminators ############## #
    from model import D_NET256
    netD = D_NET256()
    netG = EncDecNet()
    netD.apply(weights_init)
    netG.apply(weights_init)

    epoch = 0  # uncommented: the return statement below needs it
    """
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1

        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)
    """
    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        netG.cuda()
        netD.cuda()
        VGG.cuda()

    return [text_encoder, netG, netD, epoch, VGG]
def load_network(gpus: list, distributed: bool):
    netG = G_NET()
    netG.apply(weights_init)
    if distributed:
        netG = netG.cuda()
        netG = torch.nn.parallel.DistributedDataParallel(netG,
                                                         device_ids=gpus,
                                                         output_device=gpus[0],
                                                         broadcast_buffers=True)
    else:
        if cfg.CUDA:
            netG = netG.cuda()
        netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    # netsD_module = nn.ModuleList(netsD)
    # netsD_module.apply(weights_init)
    # netsD_module = torch.nn.parallel.DistributedDataParallel(
    #     netsD_module.cuda(), device_ids=gpus, output_device=gpus[0])
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        if distributed:
            netsD[i] = torch.nn.parallel.DistributedDataParallel(
                netsD[i].cuda(),
                device_ids=gpus,
                output_device=gpus[0],
                broadcast_buffers=True
                # , process_group=pg_Ds[i]
            )
        else:
            netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i),
                                    map_location=lambda storage, loc: storage)
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()
    if not distributed:
        if cfg.CUDA:
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
            inception_model = inception_model.cuda()
            inception_model = torch.nn.DataParallel(inception_model,
                                                    device_ids=gpus)
    else:
        inception_model = torch.nn.parallel.DistributedDataParallel(
            inception_model.cuda(), device_ids=gpus, output_device=gpus[0])
        # inception_model = inception_model.cpu()
        # inception_model = inception_model.to(torch.device("cuda:{}".format(gpus[0])))
    inception_model.eval()

    print("model device, G:{}, D:{}, incep:{}".format(
        netG.device_ids, netsD[0].device_ids, inception_model.device_ids))
    return netG, netsD, len(netsD), inception_model, count
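# --- Added sketch (not from the original source) ---------------------------
# The distributed branch above assumes torch.distributed is already
# initialized, one process per GPU. The torchrun-style environment variables
# and the NCCL backend are assumptions, not part of the repo.
def example_distributed_setup():
    import os
    import torch
    import torch.distributed as dist
    dist.init_process_group(backend='nccl')  # RANK/WORLD_SIZE set by torchrun
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    torch.cuda.set_device(local_rank)
    return load_network(gpus=[local_rank], distributed=True)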
def build_models(self):
    # feature extractor
    inception_model = CNN_ENCODER()
    # inception_model = CNN_dummy()

    # classifier networks
    classifiers = []
    for i in range(cfg.ATT_NUM):
        cls_model = Att_Classifier()
        classifiers.append(cls_model)

    # ####################### generator and discriminators ############### #
    netsG = []
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        """
        if not self.args.kl_loss:
            netG = G_NET_not_CA(self.args)
        else:
            netG = G_NET()
        """
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
            netsG.append(G_NET_not_CA_stage1(self.args))
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
            netsG.append(G_NET_not_CA_stage2(self.args))
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
            netsG.append(G_NET_not_CA_stage3(self.args))
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    for i in range(len(netsG)):
        netsG[i].apply(weights_init)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        # NOTE: netG only exists in the B_DCGAN branch above; the staged
        # netsG generators are not restored by this block.
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        inception_model.cuda()
        for i in range(len(netsG)):
            netsG[i].cuda()
            netsD[i].cuda()
        for i in range(len(classifiers)):
            classifiers[i].cuda()

    return [netsG, netsD, inception_model, classifiers, epoch]
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    # MODIFIED
    if cfg.PRETRAINED_G != '':
        state_dict = torch.load(cfg.PRETRAINED_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.PRETRAINED_G)
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.PRETRAINED_G
            s_tmp = Gname[:Gname.rfind('/')]
            for i in range(len(netsD)):
                # the names of the Ds should be consistent and differ only in i
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = torch.load(
                    Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    # vgg16 network
    style_loss = VGGNet()
    for p in style_loss.parameters():
        p.requires_grad = False
    print("Load the style loss model")
    style_loss.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

    netG.apply(weights_init)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # Create a target network.
    target_netG = deepcopy(netG)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        style_loss = style_loss.cuda()
        # The target network is stored on the secondary GPU. -----------------
        target_netG.cuda(secondary_device)
        target_netG.ca_net.device = secondary_device
        # ---------------------------------------------------------------------
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i] = netsD[i].cuda()

    # Disable training in the target network:
    for p in target_netG.parameters():
        p.requires_grad = False

    return [
        text_encoder, image_encoder, netG, target_netG, netsD, epoch,
        style_loss
    ]
def build_models(self):
    def count_parameters(model):
        total_param = 0
        for name, param in model.named_parameters():
            if param.requires_grad:
                num_param = np.prod(param.size())
                if param.dim() > 1:
                    print(name, ':',
                          'x'.join(str(x) for x in list(param.size())),
                          '=', num_param)
                else:
                    print(name, ':', num_param)
                total_param += num_param
        return total_param

    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    print('number of trainable parameters =', count_parameters(netG))
    print('number of trainable parameters =', count_parameters(netsD[-1]))

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        raise FileNotFoundError(
            'No pretrained text encoder found in directory DAMSMencoders/. \n'
            'Please train the DAMSM first before training the GAN (see README for details).')

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    if self.text_encoder_type == 'rnn':
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    elif self.text_encoder_type == 'transformer':
        text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    elif cfg.GAN.B_STYLEGEN:
        netG = G_NET_STYLED()
        if cfg.GAN.B_STYLEDISC:
            from model import D_NET_STYLED64, D_NET_STYLED128, D_NET_STYLED256
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET_STYLED64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET_STYLED128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET_STYLED256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        else:
            from model import D_NET64, D_NET128, D_NET256
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print(netG.__class__)
    for i in netsD:
        print(i.__class__)
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        if cfg.GAN.B_STYLEGEN:
            netG.w_ewma = state_dict['w_ewma']
            if cfg.CUDA:
                netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
            netG.load_state_dict(state_dict['netG_state_dict'])
        else:
            netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        text_encoder = nn.DataParallel(text_encoder)
        image_encoder = nn.DataParallel(image_encoder)
        netG = nn.DataParallel(netG)
        for i in range(len(netsD)):
            netsD[i] = nn.DataParallel(netsD[i])

    image_encoder.to(self.device)
    text_encoder.to(self.device)
    netG.to(self.device)
    for i in range(len(netsD)):
        netsD[i].to(self.device)

    # if cfg.CUDA and torch.cuda.is_available():
    #     text_encoder = text_encoder.cuda()
    #     image_encoder = image_encoder.cuda()
    #     netG.cuda()
    #     for i in range(len(netsD)):
    #         netsD[i].cuda()
    # if cfg.PARALLEL:
    #     netG = torch.nn.DataParallel(netG, device_ids=[0, 1, 2])
    #     text_encoder = torch.nn.DataParallel(text_encoder, device_ids=[0, 1, 2])
    #     image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[0, 1, 2])
    #     for i in range(len(netsD)):
    #         netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=[0, 1, 2])

    return [text_encoder, image_encoder, netG, netsD, epoch]
def genDiscOutputs(self, split_dir, num_samples=57140):
    if cfg.TRAIN.NET_G == '':
        logger.error('Error: the path for the models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'
        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.to(cfg.DEVICE)
        netG.eval()

        # text encoder
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)  # HACK
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        text_encoder = text_encoder.to(cfg.DEVICE)
        text_encoder.eval()
        logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E)

        batch_size = self.batch_size[0]
        nz = cfg.GAN.GLOBAL_Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE)
        local_noise = Variable(
            torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        for key in state_dict.keys():
            print(key)
        logger.info('Load G from: %s', model_dir)
        max_objects = 3

        from model import D_NET256
        netD = D_NET256()
        netD.load_state_dict(state_dict["netD"][2])
        netD.eval()

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
        save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
        mkdir_p(save_dir)
        logger.info("Saving images to: {}".format(save_dir))

        number_batches = num_samples // batch_size
        if number_batches < 1:
            number_batches = 1
        data_iter = iter(self.data_loader)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        for step in tqdm(range(number_batches)):
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys, transformation_matrices, \
                label_one_hot, _ = prepare_data(data, eval=True)

            transf_matrices = transformation_matrices[0]
            transf_matrices_inv = transformation_matrices[1]

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            #######################################################
            noise.data.normal_(0, 1)
            local_noise.data.normal_(0, 1)
            inputs = (noise, local_noise, sent_emb, words_embs, mask,
                      transf_matrices, transf_matrices_inv, label_one_hot,
                      max_objects)
            inputs = tuple(
                (inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp)
                for inp in inputs)
            with torch.no_grad():
                fake_imgs, _, mu, logvar = netG(*inputs)

            inputs = (fake_imgs, fake_labels, transf_matrices,
                      transf_matrices_inv, max_objects)
            # netD is the single 256x256 discriminator built above
            # (the original referenced an undefined netsD list here).
            codes = netD.partial_forward(*inputs)
def build_models(self):
    # ################### models ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return
    if cfg.TRAIN.NET_G == '':
        print('Error: no pretrained main module')
        return

    VGG = VGGNet()
    for p in VGG.parameters():
        p.requires_grad = False
    print("Load the VGG model")
    VGG.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
        from model import D_NET256 as D_NET
        netD = D_NET(b_jcu=False)
    else:
        from model import D_NET256
        netG = G_NET()
        netD = D_NET256()
    netD.apply(weights_init)

    state_dict = \
        torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    netG.eval()
    print('Load G from: ', cfg.TRAIN.NET_G)

    epoch = 0
    netDCM = DCM_Net()
    if cfg.TRAIN.NET_C != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_C, map_location=lambda storage, loc: storage)
        netDCM.load_state_dict(state_dict)
        print('Load DCM from: ', cfg.TRAIN.NET_C)
        istart = cfg.TRAIN.NET_C.rfind('_') + 1
        iend = cfg.TRAIN.NET_C.rfind('.')
        epoch = cfg.TRAIN.NET_C[istart:iend]
        epoch = int(epoch) + 1

    if cfg.TRAIN.NET_D != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_D, map_location=lambda storage, loc: storage)
        netD.load_state_dict(state_dict)
        print('Load DCM Discriminator from: ', cfg.TRAIN.NET_D)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        netDCM.cuda()
        VGG = VGG.cuda()
        netD.cuda()

    return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]
def build_models(self):
    # ################### encoders ################################## #
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.train()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.PRETRAINED_CNN:
        image_encoder_params = torch.load(
            cfg.PRETRAINED_CNN, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(image_encoder_params)

    if cfg.PRETRAINED_G != '':
        state_dict = torch.load(cfg.PRETRAINED_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.PRETRAINED_G)
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.PRETRAINED_G
            s_tmp = Gname[:Gname.rfind('/')]
            for i in range(len(netsD)):
                # the names of the Ds should be consistent and differ only in i
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = torch.load(
                    Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [image_encoder, netG, netsD, epoch]
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    ####################################################################
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        # keep the image encoder trainable
        p.requires_grad = True
    # freeze layers 1-5 (train(False) also puts the BN layers in eval mode)
    for k, v in image_encoder.named_children():
        if k in frozen_list_image_encoder:
            v.train(False)
            v.requires_grad_(False)
    print('Load image encoder from:', img_encoder_path)
    # image_encoder.eval()

    ###################################################################
    text_encoder = TEXT_TRANSFORMER_ENCODERv2(
        emb=cfg.TEXT.EMBEDDING_DIM,
        heads=8,
        depth=1,
        seq_length=cfg.TEXT.WORDS_NUM,
        num_tokens=self.n_words)
    # state_dict = torch.load(cfg.TRAIN.NET_E)
    # text_encoder.load_state_dict(state_dict)
    # print('Load ', cfg.TRAIN.NET_E)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = True
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    # text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################## #
    config = Config()
    cap_model = caption.build_model_v3(config)
    print("Initializing from Checkpoint...")
    cap_model_path = cfg.TRAIN.NET_E.replace('text_encoder', 'cap_model')
    if os.path.exists(cap_model_path):
        print('Load C from: {0}'.format(cap_model_path))
        state_dict = \
            torch.load(cap_model_path, map_location=lambda storage, loc: storage)
        cap_model.load_state_dict(state_dict['model'])
    else:
        base_line_path = 'catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth'
        print('Load C from: {0}'.format(base_line_path))
        checkv3 = torch.load(base_line_path, map_location='cpu')
        cap_model.load_state_dict(checkv3['model'], strict=False)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        cap_model = cap_model.cuda()  # caption transformer added
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch, cap_model]
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    from model import D_NET64, D_NET128, D_NET256
    netG = G_NET()
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if self.resume:
        checkpoint_list = sorted(
            [ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')])
        latest_checkpoint = checkpoint_list[-1]
        state_dict = torch.load(latest_checkpoint,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        for i in range(len(netsD)):
            netsD[i].load_state_dict(state_dict["netD"][i])
        epoch = int(latest_checkpoint[-8:-4]) + 1
        print("Resuming training from checkpoint {} at epoch {}.".format(
            latest_checkpoint, epoch))

    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
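# --- Added sketch (not from the original source) ---------------------------
# A save routine consistent with the resume logic above, which parses the
# epoch from latest_checkpoint[-8:-4] and relies on lexicographic ordering of
# glob() results. The exact filename stem is an assumption; only the 4-digit
# zero-padded epoch immediately before '.pth' is required by that slice.
def example_save_checkpoint(model_dir, epoch, netG, netsD):
    import torch
    path = '%s/checkpoint_%04d.pth' % (model_dir, epoch)
    torch.save({'netG': netG.state_dict(),
                'netD': [d.state_dict() for d in netsD]},
               path)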
def build_models(self):
    # ################### encoders ################################## #
    if cfg.TRAIN.NET_E == '':
        self.logger.error('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM,
                                condition=cfg.TRAIN.MASK_COND,
                                condition_channel=1)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    self.logger.info('Load image encoder from: {}'.format(img_encoder_path))
    image_encoder.eval()

    if self.audio_flag:
        text_encoder = CNNRNN_Attn(n_filters=40,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM,
                                   nsent=cfg.TEXT.SENT_EMBEDDING_DIM)
    else:
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM,
                        nsent=cfg.TEXT.SENT_EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    self.logger.info('Load text encoder from: {}'.format(cfg.TRAIN.NET_E))
    text_encoder.eval()

    # ####################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    self.logger.info('# of netsD: {}'.format(len(netsD)))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        self.logger.info('Load G from: {}'.format(cfg.TRAIN.NET_G))
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                self.logger.info('Load D from: {}'.format(Dname))
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]