示例#1
0
文件: trainer.py 项目: zxs789/Obj-GAN
    def build_models(self):
        """Build the layout generator and its instance/global discriminators.

        Creates ``G_NET``, ``INS_D_NET`` and ``GLB_D_NET`` (each sized by
        ``len(self.cats_index_dict)``) and initialises them with
        ``weights_init``. When ``cfg.CUDA`` is set, they are moved to the GPU
        and wrapped in ``nn.DataParallel`` if more than one GPU id is
        configured. When ``cfg.TRAIN.NET_G`` is non-empty, all three networks
        are restored from checkpoints.

        The discriminator checkpoints are read from the directory of the
        generator checkpoint as ``netINSD.pth`` and ``netGLBD.pth``. The
        starting epoch is the integer between the last ``_`` and the last
        ``.`` of the generator filename, plus one (``netG_epoch_10.pth`` ->
        11).

        Returns:
            [netG, netINSD, netGLBD, epoch]; ``epoch`` is 0 when no generator
            checkpoint is configured.
        """
        import os

        def _load(path):
            # Deserialise onto the CPU regardless of the device it was saved from.
            return torch.load(path, map_location=lambda storage, loc: storage)

        n_cats = len(self.cats_index_dict)
        netG = G_NET(n_cats)
        netINSD = INS_D_NET(n_cats)
        netGLBD = GLB_D_NET(n_cats)

        for net in (netG, netINSD, netGLBD):
            net.apply(weights_init)

        if cfg.CUDA:
            netG.cuda()
            netINSD.cuda()
            netGLBD.cuda()

            if len(cfg.GPU_IDS) > 1:
                netG = nn.DataParallel(netG)
                netG.to(self.device)
                netINSD = nn.DataParallel(netINSD)
                netINSD.to(self.device)
                netGLBD = nn.DataParallel(netGLBD)
                netGLBD.to(self.device)

        # ########################################################### #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            Gname = cfg.TRAIN.NET_G
            netG.load_state_dict(_load(Gname))
            print('Load G from: ', Gname)
            filename = path_leaf(Gname)
            istart = filename.rfind('_') + 1
            iend = filename.rfind('.')
            epoch = int(filename[istart:iend]) + 1

            # os.path.dirname() returns '' for a bare filename, so the
            # discriminators are then looked up in the current directory.
            # Slicing at Gname.rfind('/') would drop the last character of
            # such a name (rfind returns -1) and build a bogus path.
            model_dir = os.path.dirname(Gname)

            Dname = os.path.join(model_dir, 'netINSD.pth')
            print('Load INSD from: ', Dname)
            netINSD.load_state_dict(_load(Dname))

            Dname = os.path.join(model_dir, 'netGLBD.pth')
            print('Load GLBD from: ', Dname)
            netGLBD.load_state_dict(_load(Dname))

        return [netG, netINSD, netGLBD, epoch]
示例#2
0
    def build_models(self):
        """Build the frozen encoders, the generator and the discriminators.

        The text encoder is restored from ``cfg.TRAIN.NET_E``. The image
        encoder is restored from the same path with ``text_encoder`` replaced
        by ``image_encoder``. Both get ``requires_grad = False`` and are put
        in eval mode. ``G_NET`` and one discriminator per tree branch
        (``D_NET64``/``D_NET128``/``D_NET256``, up to ``cfg.TREE.BRANCH_NUM``)
        are created and initialised with ``weights_init``.

        Weights are then restored, in this order:
          * if ``self.resume``: from the lexicographically last ``*.pth`` in
            ``self.model_dir``, a dict with ``"netG"`` and ``"netD"`` entries;
          * if ``cfg.TRAIN.NET_G`` is set: the generator from that file and,
            when ``cfg.TRAIN.B_NET_D``, ``netD<i>.pth`` from the same
            directory. This overrides any resumed weights and epoch.

        Returns:
            [text_encoder, image_encoder, netG, netsD, epoch]

        Raises:
            Exception: if ``cfg.TRAIN.NET_E`` is empty.
            FileNotFoundError: if ``self.resume`` is set but
                ``self.model_dir`` contains no ``*.pth`` file.
        """
        import os

        def _load(path):
            # Deserialise onto the CPU; modules are moved to cfg.DEVICE below.
            return torch.load(path, map_location=lambda storage, loc: storage)

        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            raise Exception('Error: no pretrained text encoder')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        image_encoder.load_state_dict(_load(img_encoder_path))
        for p in image_encoder.parameters():
            p.requires_grad = False
        logger.info('Load image encoder from: %s', img_encoder_path)
        image_encoder.eval()

        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(_load(cfg.TRAIN.NET_E))
        for p in text_encoder.parameters():
            p.requires_grad = False
        logger.info('Load text encoder from: %s', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

        netG.apply(weights_init)
        for netD in netsD:
            netD.apply(weights_init)
        logger.info('# of params in netG: %s', count_learnable_params(netG))
        logger.info('# of netsD: %s', len(netsD))
        logger.info('# of params in netsD: %s',
                    [count_learnable_params(netD) for netD in netsD])
        epoch = 0

        if self.resume:
            checkpoint_list = sorted(glob.glob(os.path.join(self.model_dir, '*.pth')))
            if not checkpoint_list:
                raise FileNotFoundError(
                    'Cannot resume: no *.pth checkpoint found in %s' % self.model_dir)
            # NOTE(review): a lexicographic sort only yields the newest file if
            # the epoch in the name is zero-padded. The [-8:-4] slice assumes
            # names ending in a 4-digit epoch followed by '.pth'.
            latest_checkpoint = checkpoint_list[-1]
            state_dict = _load(latest_checkpoint)

            netG.load_state_dict(state_dict["netG"])
            for i, netD in enumerate(netsD):
                netD.load_state_dict(state_dict["netD"][i])
            epoch = int(latest_checkpoint[-8:-4]) + 1
            logger.info("Resuming training from checkpoint {} at epoch {}.".format(latest_checkpoint, epoch))

        # An explicit generator checkpoint takes precedence over a resume.
        if cfg.TRAIN.NET_G != '':
            Gname = cfg.TRAIN.NET_G
            netG.load_state_dict(_load(Gname))
            logger.info('Load G from: %s', Gname)
            # Parse the epoch from the filename only, so underscores or dots
            # in directory names cannot be picked up.
            filename = os.path.basename(Gname)
            istart = filename.rfind('_') + 1
            iend = filename.rfind('.')
            epoch = int(filename[istart:iend]) + 1
            if cfg.TRAIN.B_NET_D:
                # dirname() is '' for a bare filename, which os.path.join turns
                # into a path in the current directory. Slicing at rfind('/')
                # would instead drop the last character of the filename.
                model_dir = os.path.dirname(Gname)
                for i, netD in enumerate(netsD):
                    Dname = os.path.join(model_dir, 'netD%d.pth' % i)
                    logger.info('Load D from: %s', Dname)
                    netD.load_state_dict(_load(Dname))
        # ########################################################### #
        if cfg.CUDA:
            text_encoder.to(cfg.DEVICE)
            image_encoder.to(cfg.DEVICE)
            netG.to(cfg.DEVICE)
            if self.n_gpu > 1:
                netG = DataParallelPassThrough(netG)
            for i, netD in enumerate(netsD):
                netD.to(cfg.DEVICE)
                if self.n_gpu > 1:
                    netsD[i] = DataParallelPassThrough(netD)
        return [text_encoder, image_encoder, netG, netsD, epoch]
示例#3
0
    def build_models(self):
        """Build the frozen encoders, the generator and all discriminators.

        The text encoder is restored from ``cfg.TRAIN.NET_E``. The image
        encoder is restored from the same path with ``text_encoder`` replaced
        by ``image_encoder``. Both get ``requires_grad = False`` and are put
        in eval mode.

        Networks created here, all sized by ``len(self.cats_index_dict)``
        where they take a category count and all initialised with
        ``weights_init``:
          * the generator ``G_NET``;
          * one patch discriminator and one shape discriminator per tree
            branch (up to ``cfg.TREE.BRANCH_NUM``);
          * the object discriminators ``OBJ_SS_D_NET`` and ``OBJ_LS_D_NET``.

        When ``cfg.CUDA`` is set, everything is moved to the GPU and wrapped
        in ``nn.DataParallel`` for more than one GPU id.

        If ``cfg.TRAIN.NET_G`` is set, the generator is restored from it. The
        discriminators are restored from ``netPatD<i>.pth``,
        ``netShpD<i>.pth``, ``netObjSSD.pth`` and ``netObjLSD.pth`` in the
        same directory.

        Returns:
            [text_encoder, image_encoder, netG, netsPatD, netsShpD,
            netObjSSD, netObjLSD, epoch]; ``epoch`` is 0 without a generator
            checkpoint.

        Raises:
            ValueError: if ``cfg.TRAIN.NET_E`` is empty.
        """
        import os

        def _load(path):
            # Deserialise onto the CPU regardless of the device it was saved from.
            return torch.load(path, map_location=lambda storage, loc: storage)

        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            # The normal return value is an 8-element list; returning None
            # here only postponed the failure to the caller (presumably an
            # unpacking TypeError), so fail loudly instead.
            raise ValueError('Error: no pretrained text-image encoders')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                   'image_encoder')
        image_encoder.load_state_dict(_load(img_encoder_path))
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        text_encoder.load_state_dict(_load(cfg.TRAIN.NET_E))
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        n_cats = len(self.cats_index_dict)
        netG = G_NET(n_cats)
        netsPatD, netsShpD = [], []
        if cfg.TREE.BRANCH_NUM > 0:
            netsPatD.append(PAT_D_NET64())
            netsShpD.append(SHP_D_NET64(n_cats))
        if cfg.TREE.BRANCH_NUM > 1:
            netsPatD.append(PAT_D_NET128())
            netsShpD.append(SHP_D_NET128(n_cats))
        if cfg.TREE.BRANCH_NUM > 2:
            netsPatD.append(PAT_D_NET256())
            netsShpD.append(SHP_D_NET256(n_cats))

        netObjSSD = OBJ_SS_D_NET(n_cats)
        netObjLSD = OBJ_LS_D_NET(n_cats)

        netG.apply(weights_init)
        netObjSSD.apply(weights_init)
        netObjLSD.apply(weights_init)
        for netPatD, netShpD in zip(netsPatD, netsShpD):
            netPatD.apply(weights_init)
            netShpD.apply(weights_init)
        print('# of netsPatD', len(netsPatD))
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            netObjSSD.cuda()
            netObjLSD.cuda()
            for netPatD, netShpD in zip(netsPatD, netsShpD):
                netPatD.cuda()
                netShpD.cuda()

            if len(cfg.GPU_IDS) > 1:
                def _parallel(module):
                    # Wrap for multi-GPU execution and place on self.device.
                    wrapped = nn.DataParallel(module)
                    wrapped.to(self.device)
                    return wrapped

                text_encoder = _parallel(text_encoder)
                image_encoder = _parallel(image_encoder)
                netG = _parallel(netG)
                netObjSSD = _parallel(netObjSSD)
                netObjLSD = _parallel(netObjLSD)
                for i in range(len(netsPatD)):
                    netsPatD[i] = _parallel(netsPatD[i])
                    netsShpD[i] = _parallel(netsShpD[i])
        #
        epoch = 0
        if cfg.TRAIN.NET_G != '':
            Gname = cfg.TRAIN.NET_G
            netG.load_state_dict(_load(Gname))
            print('Load G from: ', Gname)
            # Parse the epoch from the filename only, so underscores or dots
            # in directory names cannot be picked up.
            filename = os.path.basename(Gname)
            istart = filename.rfind('_') + 1
            iend = filename.rfind('.')
            epoch = int(filename[istart:iend]) + 1

            # dirname() is '' for a bare filename (-> current directory).
            # Slicing at rfind('/') would drop the last character of the name.
            model_dir = os.path.dirname(Gname)
            for i in range(len(netsPatD)):
                Dname = os.path.join(model_dir, 'netPatD%d.pth' % i)
                print('Load PatD from: ', Dname)
                netsPatD[i].load_state_dict(_load(Dname))

                Dname = os.path.join(model_dir, 'netShpD%d.pth' % i)
                print('Load ShpD from: ', Dname)
                netsShpD[i].load_state_dict(_load(Dname))

            Dname = os.path.join(model_dir, 'netObjSSD.pth')
            print('Load ObjSSD from: ', Dname)
            netObjSSD.load_state_dict(_load(Dname))

            Dname = os.path.join(model_dir, 'netObjLSD.pth')
            print('Load ObjLSD from: ', Dname)
            netObjLSD.load_state_dict(_load(Dname))

        return [
            text_encoder, image_encoder, netG, netsPatD, netsShpD, netObjSSD,
            netObjLSD, epoch
        ]
示例#4
0
文件: trainer.py 项目: zxs789/Obj-GAN
    def sampling(self, split_dir):
        """Generate and save one map per example with the trained generator.

        Loads ``G_NET`` from ``cfg.TRAIN.NET_G`` and runs it over every batch
        of ``self.data_loader`` with fresh standard-normal noise. For each
        example, the map of its first object is min-max scaled to 0-255 and
        written as ``<NET_G without extension>/<split_dir>/single/<key>_<category>.png``.

        Args:
            split_dir: Name of the output sub-directory; ``'test'`` is
                mapped to ``'valid'``.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        # Build the generator; its weights are replaced by the checkpoint below.
        n_cats = len(self.cats_index_dict)
        netG = G_NET(n_cats)
        netG.apply(weights_init)
        netG.eval()

        if cfg.CUDA:
            netG.cuda()

        if len(cfg.GPU_IDS) > 1:
            netG = nn.DataParallel(netG)
            netG.to(self.device)

        batch_size = self.batch_size
        noise = Variable(
            torch.FloatTensor(batch_size, cfg.ROI.BOXES_NUM, n_cats * 4))
        if cfg.CUDA:
            # Only move the noise to the GPU when CUDA is enabled; an
            # unconditional .cuda() breaks CPU-only runs.
            noise = noise.cuda()

        model_path = cfg.TRAIN.NET_G
        netG.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        print('Load G from: ', model_path)

        # Output root: the checkpoint path without its extension, e.g.
        # 'models/netG_10.pth' -> 'models/netG_10/<split_dir>'.
        save_dir = '%s/%s' % (os.path.splitext(model_path)[0], split_dir)
        mkdir_p(save_dir)

        # Each example is sampled once.
        for step, data in enumerate(self.data_loader, 0):
            if step % 100 == 0:
                print('step: ', step)
            imgs, pooled_hmaps, hmaps, bbox_maps_fwd, bbox_maps_bwd, bbox_fmaps, \
                rois, fm_rois, num_rois, class_ids, keys = prepare_data(data)
            num_rois = num_rois.data.cpu().numpy()

            # Category name of every ROI, per example. Column 4 of a ROI holds
            # the dataset-relative category id.
            cats_list = []
            for batch_index in range(batch_size):
                cats = []
                for roi_index in range(num_rois[batch_index]):
                    rela_cat_id = int(rois[batch_index, roi_index, 4])
                    abs_cat_id = self.cats_dict[rela_cat_id][0]
                    cats.append(self.ixtoword[abs_cat_id].encode(
                        'ascii', 'ignore').decode('ascii'))
                cats_list.append(cats)

            #######################################################
            # (2) Generate fake images
            ######################################################
            max_num_roi = max(num_rois)
            noise.data.normal_(0, 1)
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                fake_hmaps = netG(noise[:, :max_num_roi], bbox_maps_fwd,
                                  bbox_maps_bwd, bbox_fmaps)
            fake_hmaps = fake_hmaps.repeat(1, 1, 3, 1, 1)
            for j in range(batch_size):
                if not cats_list[j]:
                    # No objects in this example: no map/category to save.
                    continue
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                k = 0  # only the first object's map is saved
                im = fake_hmaps[j][k].data.cpu().numpy()

                min_v = im.min()
                max_v = im.max()
                if max_v > min_v:
                    im = (im - min_v) / (max_v - min_v)
                else:
                    # Constant map: avoid 0/0 (NaN) and write a black image.
                    im = np.zeros_like(im)
                im = (im * 255).astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)

                cat = cats_list[j][k]
                fullpath = '{0}_{1}.png'.format(s_tmp, cat)
                im.save(fullpath)
示例#5
0
    def build_models(self):
        """Create the inference-time encoders and generator(s).

        Steps:
          * Restore the text encoder from ``cfg.TRAIN.NET_E``.
          * Restore the image encoder from the same path with
            ``text_encoder`` replaced by ``image_encoder``, and freeze its
            parameters.
          * Build ``G_NET`` and restore it from ``cfg.TRAIN.NET_G``.
          * Only when ``cfg.TEST.USE_GT_BOX_SEG > 0``, build ``SHP_G_NET``
            and restore it from ``cfg.TEST.NET_SHP_G``.

        Every network is switched to eval mode and, when ``cfg.CUDA`` is
        set, moved to the GPU before its weights are loaded.

        Returns:
            [text_encoder, image_encoder, netG, netShpG]; ``netShpG`` is
            ``None`` unless the shape generator is enabled.
        """
        def _cpu_state(path):
            # Deserialise checkpoint tensors onto the CPU.
            return torch.load(path, map_location=lambda storage, loc: storage)

        # ############################## encoders ############################# #
        txt_enc = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        txt_enc.load_state_dict(_cpu_state(cfg.TRAIN.NET_E))
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        txt_enc.eval()

        img_enc = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        img_enc.load_state_dict(_cpu_state(img_path))
        for param in img_enc.parameters():
            param.requires_grad = False
        print('Load image encoder from:', img_path)
        img_enc.eval()

        # ########### image generator and (potential) shape generator ########## #
        num_cats = len(self.cats_index_dict)
        gen = G_NET(num_cats)
        gen.apply(weights_init)
        gen.eval()

        with_shape_gen = cfg.TEST.USE_GT_BOX_SEG > 0
        shape_gen = None
        if with_shape_gen:
            shape_gen = SHP_G_NET(num_cats)
            shape_gen.apply(weights_init)
            shape_gen.eval()

        # ################### parallization and initialization ################## #
        if cfg.CUDA:
            for module in (txt_enc, img_enc, gen):
                module.cuda()
            if with_shape_gen:
                shape_gen.cuda()

            if len(cfg.GPU_IDS) > 1:
                txt_enc = nn.DataParallel(txt_enc)
                txt_enc.to(self.device)
                img_enc = nn.DataParallel(img_enc)
                img_enc.to(self.device)
                gen = nn.DataParallel(gen)
                gen.to(self.device)

            # NOTE(review): unlike the modules above, the shape generator is
            # wrapped in DataParallel whenever CUDA is on, even with a single
            # GPU id. Its checkpoint is therefore presumably expected to carry
            # DataParallel-style ('module.'-prefixed) keys; confirm this
            # asymmetry is intended.
            if with_shape_gen:
                shape_gen = nn.DataParallel(shape_gen)
                shape_gen.to(self.device)

        gen.load_state_dict(_cpu_state(cfg.TRAIN.NET_G))
        print('Load G from: ', cfg.TRAIN.NET_G)

        if with_shape_gen:
            shape_gen.load_state_dict(_cpu_state(cfg.TEST.NET_SHP_G))
            print('Load Shape G from: ', cfg.TEST.NET_SHP_G)

        return [txt_enc, img_enc, gen, shape_gen]