def build_models(self):
    """Create the generator and both discriminators, optionally restoring them.

    The three networks are each constructed with ``len(self.cats_index_dict)``
    as their only argument and initialised with ``weights_init``.  When
    ``cfg.CUDA`` is set they are moved to the GPU and, if ``cfg.GPU_IDS``
    lists more than one id, wrapped in ``nn.DataParallel``.

    If ``cfg.TRAIN.NET_G`` is non-empty, the generator is restored from that
    file and the discriminators from ``netINSD.pth`` / ``netGLBD.pth`` located
    in the same directory.  The text between the last ``_`` and the last ``.``
    of the generator file name is parsed as an integer; the returned epoch is
    that value plus one.

    Returns:
        list: ``[netG, netINSD, netGLBD, epoch]``; ``epoch`` is 0 when no
        generator checkpoint is configured.
    """
    # Function-scope import: the file's top-level import block is not visible
    # from this definition.
    import os

    def load_cpu(path):
        # Deserialize every storage onto the CPU, whatever device it was
        # saved from.
        return torch.load(path, map_location=lambda storage, loc: storage)

    n_cats = len(self.cats_index_dict)
    netG = G_NET(n_cats)
    netINSD = INS_D_NET(n_cats)
    netGLBD = GLB_D_NET(n_cats)

    for net in (netG, netINSD, netGLBD):
        net.apply(weights_init)

    if cfg.CUDA:
        netG.cuda()
        netINSD.cuda()
        netGLBD.cuda()

        # NOTE(review): the multi-GPU wrapping is nested under cfg.CUDA here;
        # the flattened source does not show the original nesting - confirm.
        if len(cfg.GPU_IDS) > 1:
            netG = nn.DataParallel(netG)
            netG.to(self.device)

            netINSD = nn.DataParallel(netINSD)
            netINSD.to(self.device)

            netGLBD = nn.DataParallel(netGLBD)
            netGLBD.to(self.device)

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(load_cpu(g_path))
        print('Load G from: ', g_path)

        # The epoch number is encoded in the file name as "<prefix>_<epoch>.<ext>".
        filename = path_leaf(g_path)
        istart = filename.rfind('_') + 1
        iend = filename.rfind('.')
        epoch = int(filename[istart:iend]) + 1

        # dirname/join instead of slicing at rfind('/'): for a path without
        # any '/', rfind returns -1 and the slice would silently drop the
        # last character of the file name instead of yielding its directory.
        ckpt_dir = os.path.dirname(g_path)
        for label, net in (('INSD', netINSD), ('GLBD', netGLBD)):
            d_path = os.path.join(ckpt_dir, 'net%s.pth' % label)
            print('Load %s from: ' % label, d_path)
            net.load_state_dict(load_cpu(d_path))

    return [netG, netINSD, netGLBD, epoch]
def build_models(self):
    """Load the frozen encoders and build the generator and discriminators.

    The image-encoder checkpoint path is derived from ``cfg.TRAIN.NET_E`` by
    replacing ``'text_encoder'`` with ``'image_encoder'``.  Both encoders get
    ``requires_grad = False`` on every parameter and are put in eval mode.

    One discriminator is created per branch, up to three, according to
    ``cfg.TREE.BRANCH_NUM``.

    Weights are restored in two optional, independent steps:

    * ``self.resume``: the lexicographically last ``*.pth`` file in
      ``self.model_dir`` is loaded; it must hold ``"netG"`` and ``"netD"``
      entries.  The epoch is read from the four characters preceding the
      four-character suffix of the file name.
    * ``cfg.TRAIN.NET_G`` non-empty: the generator (and the epoch) are taken
      from that file, overriding the resume step; when ``cfg.TRAIN.B_NET_D``
      is set, discriminators are read from ``netD<i>.pth`` next to it.

    Returns:
        list: ``[text_encoder, image_encoder, netG, netsD, epoch]``.

    Raises:
        Exception: if ``cfg.TRAIN.NET_E`` is empty.
        FileNotFoundError: if ``self.resume`` is set but ``self.model_dir``
            contains no ``*.pth`` checkpoint.
    """
    # Function-scope import: the file's top-level import block is not visible
    # from this definition.
    import os

    def load_cpu(path):
        # Deserialize every storage onto the CPU, whatever device it was
        # saved from.
        return torch.load(path, map_location=lambda storage, loc: storage)

    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        raise Exception('Error: no pretrained text encoder')

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(load_cpu(img_encoder_path))
    for p in image_encoder.parameters():
        p.requires_grad = False
    logger.info('Load image encoder from: %s', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_cpu(cfg.TRAIN.NET_E))
    for p in text_encoder.parameters():
        p.requires_grad = False
    logger.info('Load text encoder from: %s', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    netsD = []
    from model import D_NET64, D_NET128, D_NET256
    netG = G_NET()
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)

    logger.info('# of params in netG: %s', count_learnable_params(netG))
    logger.info('# of netsD: %s', len(netsD))
    logger.info('# of params in netsD: %s',
                [count_learnable_params(netD) for netD in netsD])

    epoch = 0
    if self.resume:
        checkpoint_list = sorted(glob.glob(os.path.join(self.model_dir, '*.pth')))
        if not checkpoint_list:
            # Indexing [-1] on an empty list would only report
            # "list index out of range"; say what is actually missing.
            raise FileNotFoundError(
                'Cannot resume: no *.pth checkpoint found in %s' % self.model_dir)
        latest_checkpoint = checkpoint_list[-1]
        state_dict = load_cpu(latest_checkpoint)

        netG.load_state_dict(state_dict["netG"])
        for netD, d_state in zip(netsD, state_dict["netD"]):
            netD.load_state_dict(d_state)
        # File names are expected to end in "<4-char epoch><4-char suffix>".
        epoch = int(latest_checkpoint[-8:-4]) + 1
        logger.info("Resuming training from checkpoint %s at epoch %s.",
                    latest_checkpoint, epoch)

    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(load_cpu(g_path))
        logger.info('Load G from: %s', g_path)
        istart = g_path.rfind('_') + 1
        iend = g_path.rfind('.')
        epoch = int(g_path[istart:iend]) + 1

        if cfg.TRAIN.B_NET_D:
            # dirname/join instead of slicing at rfind('/'): for a path
            # without any '/', rfind returns -1 and the slice would drop the
            # last character of the file name instead of yielding its
            # directory.
            ckpt_dir = os.path.dirname(g_path)
            for i, netD in enumerate(netsD):
                d_path = os.path.join(ckpt_dir, 'netD%d.pth' % i)
                logger.info('Load D from: %s', d_path)
                netD.load_state_dict(load_cpu(d_path))

    # ########################################################### #
    if cfg.CUDA:
        text_encoder.to(cfg.DEVICE)
        image_encoder.to(cfg.DEVICE)
        netG.to(cfg.DEVICE)
        if self.n_gpu > 1:
            netG = DataParallelPassThrough(netG)
        for i, netD in enumerate(netsD):
            netD.to(cfg.DEVICE)
            if self.n_gpu > 1:
                netsD[i] = DataParallelPassThrough(netD)

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    """Load the frozen encoders and build the generator and all discriminators.

    The image-encoder checkpoint path is derived from ``cfg.TRAIN.NET_E`` by
    replacing ``'text_encoder'`` with ``'image_encoder'``.  Both encoders get
    ``requires_grad = False`` on every parameter and are put in eval mode.

    One patch discriminator and one shape discriminator are created per
    branch, up to three, according to ``cfg.TREE.BRANCH_NUM``; two object
    discriminators (``netObjSSD``, ``netObjLSD``) are always created.  When
    ``cfg.CUDA`` is set everything is moved to the GPU and, if
    ``cfg.GPU_IDS`` lists more than one id, wrapped in ``nn.DataParallel``.

    If ``cfg.TRAIN.NET_G`` is non-empty, the generator is restored from that
    file and every discriminator from its ``net<Name>.pth`` file in the same
    directory.  The text between the last ``_`` and the last ``.`` of the
    generator path is parsed as an integer; the returned epoch is that value
    plus one.

    Returns:
        list: ``[text_encoder, image_encoder, netG, netsPatD, netsShpD,
        netObjSSD, netObjLSD, epoch]``, or ``None`` (after printing an error)
        when ``cfg.TRAIN.NET_E`` is empty.
    """
    # Function-scope import: the file's top-level import block is not visible
    # from this definition.
    import os

    def load_cpu(path):
        # Deserialize every storage onto the CPU, whatever device it was
        # saved from.
        return torch.load(path, map_location=lambda storage, loc: storage)

    def restore(net, path, label):
        # Announce, then load, one discriminator checkpoint.
        print('Load %s from: ' % label, path)
        net.load_state_dict(load_cpu(path))

    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(load_cpu(img_encoder_path))
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_cpu(cfg.TRAIN.NET_E))
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    n_cats = len(self.cats_index_dict)
    netG = G_NET(n_cats)
    netsPatD, netsShpD = [], []
    if cfg.TREE.BRANCH_NUM > 0:
        netsPatD.append(PAT_D_NET64())
        netsShpD.append(SHP_D_NET64(n_cats))
    if cfg.TREE.BRANCH_NUM > 1:
        netsPatD.append(PAT_D_NET128())
        netsShpD.append(SHP_D_NET128(n_cats))
    if cfg.TREE.BRANCH_NUM > 2:
        netsPatD.append(PAT_D_NET256())
        netsShpD.append(SHP_D_NET256(n_cats))
    netObjSSD = OBJ_SS_D_NET(n_cats)
    netObjLSD = OBJ_LS_D_NET(n_cats)

    netG.apply(weights_init)
    netObjSSD.apply(weights_init)
    netObjLSD.apply(weights_init)
    for pat_d, shp_d in zip(netsPatD, netsShpD):
        pat_d.apply(weights_init)
        shp_d.apply(weights_init)
    print('# of netsPatD', len(netsPatD))

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        netObjSSD.cuda()
        netObjLSD.cuda()
        for pat_d, shp_d in zip(netsPatD, netsShpD):
            pat_d.cuda()
            shp_d.cuda()

        # NOTE(review): the multi-GPU wrapping is nested under cfg.CUDA here;
        # the flattened source does not show the original nesting - confirm.
        if len(cfg.GPU_IDS) > 1:
            text_encoder = nn.DataParallel(text_encoder)
            text_encoder.to(self.device)

            image_encoder = nn.DataParallel(image_encoder)
            image_encoder.to(self.device)

            netG = nn.DataParallel(netG)
            netG.to(self.device)

            netObjSSD = nn.DataParallel(netObjSSD)
            netObjSSD.to(self.device)

            netObjLSD = nn.DataParallel(netObjLSD)
            netObjLSD.to(self.device)

            for i in range(len(netsPatD)):
                netsPatD[i] = nn.DataParallel(netsPatD[i])
                netsPatD[i].to(self.device)

                netsShpD[i] = nn.DataParallel(netsShpD[i])
                netsShpD[i].to(self.device)

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(load_cpu(g_path))
        print('Load G from: ', g_path)
        istart = g_path.rfind('_') + 1
        iend = g_path.rfind('.')
        epoch = int(g_path[istart:iend]) + 1

        # dirname/join instead of slicing at rfind('/'): for a path without
        # any '/', rfind returns -1 and the slice would drop the last
        # character of the file name instead of yielding its directory.
        ckpt_dir = os.path.dirname(g_path)
        for i, (pat_d, shp_d) in enumerate(zip(netsPatD, netsShpD)):
            restore(pat_d, os.path.join(ckpt_dir, 'netPatD%d.pth' % i), 'PatD')
            restore(shp_d, os.path.join(ckpt_dir, 'netShpD%d.pth' % i), 'ShpD')
        restore(netObjSSD, os.path.join(ckpt_dir, 'netObjSSD.pth'), 'ObjSSD')
        restore(netObjLSD, os.path.join(ckpt_dir, 'netObjLSD.pth'), 'ObjLSD')

    return [
        text_encoder, image_encoder, netG, netsPatD, netsShpD, netObjSSD,
        netObjLSD, epoch
    ]
def sampling(self, split_dir):
    """Run the trained generator over ``self.data_loader`` and save PNG maps.

    The generator weights are read from ``cfg.TRAIN.NET_G``.  For every
    sample in every batch, the generated map of the sample's first ROI is
    min-max scaled to 0..255 and written to
    ``<NET_G without '.pth'>/<split_dir>/single/<key>_<category>.png``.

    Args:
        split_dir: Name of the output sub-directory; ``'test'`` is mapped to
            ``'valid'``.

    Returns:
        None.  If ``cfg.TRAIN.NET_G`` is empty an error is printed and
        nothing is generated.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    # Build and load the generator
    n_cats = len(self.cats_index_dict)
    netG = G_NET(n_cats)
    netG.apply(weights_init)
    netG.eval()
    if cfg.CUDA:
        netG.cuda()
        # NOTE(review): the multi-GPU wrapping is nested under cfg.CUDA here;
        # the flattened source does not show the original nesting - confirm.
        if len(cfg.GPU_IDS) > 1:
            netG = nn.DataParallel(netG)
            netG.to(self.device)

    batch_size = self.batch_size
    noise = Variable(
        torch.FloatTensor(batch_size, cfg.ROI.BOXES_NUM, n_cats * 4))
    # Only move the noise to the GPU when the generator lives there too;
    # an unconditional .cuda() breaks CPU-only runs.
    if cfg.CUDA:
        noise = noise.cuda()

    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)

    # the path to save generated images
    s_tmp = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (s_tmp, split_dir)
    mkdir_p(save_dir)

    # The outputs are only converted to numpy and written to disk, so no
    # autograd graph is needed.
    with torch.no_grad():
        for step, data in enumerate(self.data_loader, 0):
            if step % 100 == 0:
                print('step: ', step)

            (imgs, pooled_hmaps, hmaps, bbox_maps_fwd, bbox_maps_bwd,
             bbox_fmaps, rois, fm_rois, num_rois, class_ids,
             keys) = prepare_data(data)
            num_rois = num_rois.data.cpu().numpy()

            # Category word of every ROI, per sample.
            cats_list = []
            for batch_index in range(batch_size):
                cats = []
                for roi_index in range(num_rois[batch_index]):
                    rela_cat_id = int(rois[batch_index, roi_index, 4])
                    abs_cat_id = self.cats_dict[rela_cat_id][0]
                    cat = self.ixtoword[abs_cat_id].encode(
                        'ascii', 'ignore').decode('ascii')
                    cats.append(cat)
                cats_list.append(cats)

            #######################################################
            # (2) Generate fake images
            ######################################################
            max_num_roi = max(num_rois)
            noise.data.normal_(0, 1)
            fake_hmaps = netG(noise[:, :max_num_roi], bbox_maps_fwd,
                              bbox_maps_bwd, bbox_fmaps)
            # Triple the third axis so each map can be saved as a
            # three-channel image below.
            fake_hmaps = fake_hmaps.repeat(1, 1, 3, 1, 1)

            for j in range(batch_size):
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                folder = os.path.dirname(s_tmp)
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)

                # Only the first ROI of each sample is saved.
                k = 0
                im = fake_hmaps[j][k].data.cpu().numpy()
                minV = im.min()
                maxV = im.max()
                value_range = maxV - minV
                if value_range > 0:
                    im = (im - minV) / value_range
                else:
                    # Constant map: avoid 0/0 -> NaN before the uint8 cast.
                    im = np.zeros_like(im)
                im *= 255
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                cat = cats_list[j][k]
                fullpath = '{0}_{1}.png'.format(s_tmp, cat)
                im.save(fullpath)
def build_models(self):
    """Assemble the pretrained networks used at test time.

    The text encoder is read from ``cfg.TRAIN.NET_E``; the image encoder from
    the same path with ``'text_encoder'`` replaced by ``'image_encoder'``,
    and its parameters get ``requires_grad = False``.  The image generator is
    read from ``cfg.TRAIN.NET_G``.  A shape generator is built and read from
    ``cfg.TEST.NET_SHP_G`` only when ``cfg.TEST.USE_GT_BOX_SEG > 0``.  All
    networks are put in eval mode.  With ``cfg.CUDA`` they are moved to the
    GPU and, if ``cfg.GPU_IDS`` lists more than one id, wrapped in
    ``nn.DataParallel`` before the generator weights are loaded.

    Returns:
        list: ``[text_encoder, image_encoder, netG, netShpG]`` where
        ``netShpG`` is ``None`` if no shape generator was requested.
    """

    def read_checkpoint(path):
        # Deserialize every storage onto the CPU.
        return torch.load(path, map_location=lambda storage, loc: storage)

    # ---- encoders ------------------------------------------------------ #
    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(read_checkpoint(cfg.TRAIN.NET_E))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(read_checkpoint(img_encoder_path))
    for weight in image_encoder.parameters():
        weight.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    # ---- image generator and optional shape generator ------------------ #
    num_categories = len(self.cats_index_dict)
    netG = G_NET(num_categories)
    netG.apply(weights_init)
    netG.eval()

    with_shape_generator = cfg.TEST.USE_GT_BOX_SEG > 0
    netShpG = None
    if with_shape_generator:
        netShpG = SHP_G_NET(num_categories)
        netShpG.apply(weights_init)
        netShpG.eval()

    # ---- device placement ---------------------------------------------- #
    if cfg.CUDA:
        text_encoder.cuda()
        image_encoder.cuda()
        netG.cuda()
        if with_shape_generator:
            netShpG.cuda()

        # NOTE(review): the multi-GPU wrapping is nested under cfg.CUDA here;
        # the flattened source does not show the original nesting - confirm.
        if len(cfg.GPU_IDS) > 1:
            text_encoder = nn.DataParallel(text_encoder)
            text_encoder.to(self.device)

            image_encoder = nn.DataParallel(image_encoder)
            image_encoder.to(self.device)

            netG = nn.DataParallel(netG)
            netG.to(self.device)

            if with_shape_generator:
                netShpG = nn.DataParallel(netShpG)
                netShpG.to(self.device)

    # ---- generator weights --------------------------------------------- #
    netG.load_state_dict(read_checkpoint(cfg.TRAIN.NET_G))
    print('Load G from: ', cfg.TRAIN.NET_G)

    if with_shape_generator:
        netShpG.load_state_dict(read_checkpoint(cfg.TEST.NET_SHP_G))
        print('Load Shape G from: ', cfg.TEST.NET_SHP_G)

    return [text_encoder, image_encoder, netG, netShpG]