예제 #1
0
def load_network(gpus):
    """Build the generator, per-branch discriminators and the Inception
    scoring model, resuming each from the checkpoints configured in ``cfg``.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG, netsD, num_Ds, inception_model, count)`` where
        ``count`` is the training iteration to resume from (0 for a fresh
        run).
    """
    netG = G_NET_hmx()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per enabled branch of the image-resolution tree.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # Checkpoints are named ``<prefix>_<iter>.pth``; resume at iter + 1.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            # Log the same path that is actually loaded below (the old
            # message claimed '%s_%d.pth' while '%s%d.pth' was loaded).
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    # Inception net is used for scoring only, never trained.
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
예제 #2
0
def load_network(gpus):
    """Build the generator, per-branch discriminators and the Inception
    scoring model, resuming from checkpoints and recovering the iteration
    counter.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG, netsD, num_Ds, inception_model, count)`` where
        ``count`` is the iteration to resume from (0 for a fresh run).
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per enabled branch of the image-resolution tree.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        try:
            # Checkpoints are usually named ``<prefix>_<iter>.pth``.
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            count = int(cfg.TRAIN.NET_G[istart:iend])
        except ValueError:
            # Filename carries no parsable iteration number: fall back to
            # the counter persisted by the previous run. (Was a bare
            # ``except:``, which also swallowed KeyboardInterrupt/SystemExit
            # and hid unrelated bugs.)
            last_run_dir = cfg.DATA_DIR + '/' + cfg.LAST_RUN_DIR + '/Model'
            with open(last_run_dir + '/count.txt', 'r') as f:
                count = int(f.read())

        # Resume at the iteration after the checkpointed one.
        count += 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            # Log the same path that is actually loaded below.
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    # Inception net is used for scoring only, never trained.
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
예제 #3
0
    def __init__(self, output_dir, data_loader, dataset):
        """Create output directories and set up evaluation state.

        Args:
            output_dir: root directory; ``Model``/``Image``/``Snapshot``/
                ``Score`` subdirectories are created beneath it.
            data_loader: iterable of evaluation batches.
            dataset: dataset object exposing the vocabulary, category
                metadata and a pretrained GloVe embedding module.
        """
        #if cfg.TRAIN.FLAG:
        self.model_dir = os.path.join(output_dir, 'Model')
        self.image_dir = os.path.join(output_dir, 'Image')
        self.snapshot_dir = os.path.join(output_dir, 'Snapshot')
        self.score_dir = os.path.join(output_dir, 'Score')
        mkdir_p(self.model_dir)
        mkdir_p(self.image_dir)
        mkdir_p(self.snapshot_dir)
        mkdir_p(self.score_dir)

        # NOTE(review): always selects device 0 regardless of the id stored
        # in cfg.GPU_IDS — presumably CUDA_VISIBLE_DEVICES is set upstream;
        # confirm against the launcher.
        if len(cfg.GPU_IDS) == 1 and cfg.GPU_IDS[0] >= 0:
            torch.cuda.set_device(0)
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.display_interval = cfg.TRAIN.DISPLAY_INTERVAL
        self.device = torch.device("cuda" if cfg.CUDA else "cpu")

        # Vocabulary and category metadata copied from the dataset object.
        self.n_words = dataset.n_words
        self.ixtoword = dataset.ixtoword
        self.cats_dict = dataset.cats_dict
        self.cats_index_dict = dataset.cats_index_dict
        self.cat_labels = dataset.cat_labels
        self.cat_label_lens = dataset.cat_label_lens
        self.sorted_cat_label_indices = dataset.sorted_cat_label_indices

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

        # GloVe embedding runs in eval mode; wrapped for multi-GPU when
        # more than one GPU id is configured.
        self.glove_emb = dataset.glove_embed
        if cfg.CUDA:
            self.glove_emb.cuda()
            if len(cfg.GPU_IDS) > 1:
                self.glove_emb = nn.DataParallel(self.glove_emb)
                self.glove_emb.to(self.device)
        self.glove_emb.eval()

        # Scoring backends: either the TF inception-score implementation,
        # or torch Inception nets for IS and FID respectively.
        if cfg.TEST.USE_TF:
            import miscc.inception_score_tf as inception_score
            self.inception_score = inception_score
            torch.cuda.set_device(0)
        else:
            self.inception_model = INCEPTION_V3()
            block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[cfg.TEST.FID_DIMS]
            self.inception_model_fid = INCEPTION_V3_FID([block_idx])
            if cfg.CUDA:
                self.inception_model.cuda()
                self.inception_model_fid.cuda()
                if len(cfg.GPU_IDS) > 1:
                    self.inception_model = nn.DataParallel(
                        self.inception_model)
                    self.inception_model.to(self.device)
                    self.inception_model_fid = nn.DataParallel(
                        self.inception_model_fid)
                    self.inception_model_fid.to(self.device)
            self.inception_model.eval()
            self.inception_model_fid.eval()
예제 #4
0
def load_network(gpus):
    """Build the image MLP encoder, the generator, the per-branch
    discriminators and the Inception scoring model, resuming each from its
    configured checkpoint.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG, netsD, netEn_img, inception_model, num_Ds, count)``
        where ``count`` is the iteration to resume from (0 for a fresh
        run).
    """
    netEn_img = MLP_ENCODER_IMG()
    netEn_img.apply(weights_init)
    netEn_img = torch.nn.DataParallel(netEn_img, device_ids=gpus)
    print(netEn_img)

    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per enabled branch of the image-resolution tree.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())

    # ``range`` instead of the Python-2-only ``xrange``: keeps this loader
    # consistent with the others in this file and Python-3 compatible.
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # Checkpoints are named ``<prefix>_<iter>.pth``; resume at iter + 1.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            # Log the same path that is actually loaded below.
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    if cfg.TRAIN.NET_MLP_IMG != '':
        state_dict = torch.load(cfg.TRAIN.NET_MLP_IMG)
        netEn_img.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_MLP_IMG)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        netEn_img = netEn_img.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
        inception_model = inception_model.cuda()
    # Inception net is used for scoring only, never trained.
    inception_model.eval()

    return netG, netsD, netEn_img, inception_model, len(netsD), count
예제 #5
0
def load_network(gpus):
    """Instantiate the three-scale generators and the per-branch
    discriminators, resume them from the configured checkpoints, and build
    the Inception model used for scoring.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG_64, netsD, num_Ds, inception_model, count,
        netG_128, netG_256)``.
    """
    # One generator per output resolution, each weight-initialised and then
    # wrapped for multi-GPU execution.
    netG_64 = G_NET_64()
    netG_128 = G_NET_128()
    netG_256 = G_NET_256()
    for generator in (netG_64, netG_128, netG_256):
        generator.apply(weights_init)
    netG_64 = torch.nn.DataParallel(netG_64, device_ids=gpus)
    netG_128 = torch.nn.DataParallel(netG_128, device_ids=gpus)
    netG_256 = torch.nn.DataParallel(netG_256, device_ids=gpus)
    print(netG_256)

    # One discriminator per enabled branch of the resolution tree.
    netsD = []
    for depth, make_disc in enumerate((D_NET64, D_NET128, D_NET256)):
        if cfg.TREE.BRANCH_NUM > depth:
            netsD.append(make_disc())
    for idx, disc in enumerate(netsD):
        disc.apply(weights_init)
        netsD[idx] = torch.nn.DataParallel(disc, device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    # Restore whichever generator checkpoints are configured.
    if cfg.TRAIN.NET_G_64 != '':
        netG_64.load_state_dict(torch.load(cfg.TRAIN.NET_G_64))
        print('Load ', cfg.TRAIN.NET_G_64)
    if cfg.TRAIN.NET_G_128 != '':
        netG_128.load_state_dict(torch.load(cfg.TRAIN.NET_G_128))
        print('Load ', cfg.TRAIN.NET_G_128)
    # Restore the discriminators from files named ``<prefix><i>.pth``.
    if cfg.TRAIN.NET_D != '':
        for idx in range(len(netsD)):
            print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, idx))
            netsD[idx].load_state_dict(
                torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, idx)))

    inception_model = INCEPTION_V3()
    if cfg.CUDA:
        netG_64.cuda()
        netG_128.cuda()
        netG_256.cuda()
        for disc in netsD:
            disc.cuda()
        inception_model = inception_model.cuda()
    # Scoring model only — keep it in eval mode.
    inception_model.eval()
    return (netG_64, netsD, len(netsD), inception_model, count,
            netG_128, netG_256)
예제 #6
0
    def __init__(self, output_dir, data_loader, dataset):
        """Set up training state: output directories, schedule intervals,
        dataset metadata and the frozen evaluation models.

        Args:
            output_dir: root directory; ``Model``/``Image``/``Score``
                subdirectories are created under it when training.
            data_loader: iterable of training batches.
            dataset: dataset object exposing the vocabulary, category
                metadata and a pretrained GloVe embedding module.
        """
        # Output directories are only needed (and created) when training.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.score_dir = os.path.join(output_dir, 'Score')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.score_dir)

        # NOTE(review): always selects device 0 regardless of the id stored
        # in cfg.GPU_IDS — presumably CUDA_VISIBLE_DEVICES is set upstream;
        # confirm against the launcher.
        if len(cfg.GPU_IDS) == 1 and cfg.GPU_IDS[0] >= 0:
            torch.cuda.set_device(0)
        cudnn.benchmark = True

        # Training schedule parameters from the config.
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        self.print_interval = cfg.TRAIN.PRINT_INTERVAL
        self.display_interval = cfg.TRAIN.DISPLAY_INTERVAL

        # Vocabulary and category metadata copied from the dataset object.
        self.n_words = dataset.n_words
        self.ixtoword = dataset.ixtoword
        self.cats_index_dict = dataset.cats_index_dict
        self.cat_labels = dataset.cat_labels
        self.cat_label_lens = dataset.cat_label_lens
        self.sorted_cat_label_indices = dataset.sorted_cat_label_indices

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

        self.device = torch.device("cuda" if cfg.CUDA else "cpu")

        # Inception (scoring) and GloVe embedding run in eval mode;
        # both are wrapped for multi-GPU when several GPU ids are set.
        self.inception_model = INCEPTION_V3()
        self.glove_emb = dataset.glove_embed
        if cfg.CUDA:
            self.inception_model.cuda()
            self.glove_emb.cuda()
            if len(cfg.GPU_IDS) > 1:
                self.inception_model = nn.DataParallel(self.inception_model)
                self.inception_model.to(self.device)
                self.glove_emb = nn.DataParallel(self.glove_emb)
                self.glove_emb.to(self.device)
        self.inception_model.eval()
        self.glove_emb.eval()
예제 #7
0
파일: trainer.py 프로젝트: JosephKJ/aRTISt
def load_network(gpus):
    """Construct the generator/discriminator pair (wrapped for multi-GPU),
    optionally resume both from checkpoints, and build the Inception
    network used for scoring.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG, netD, inception_model, training_iter)`` where
        ``training_iter`` is the iteration to resume from (0 for a fresh
        run).
    """
    # Generator
    netG = Generator()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # Discriminator
    netD = Discriminator()
    netD.apply(weights_init)
    netD = torch.nn.DataParallel(netD, device_ids=gpus)
    print(netD)

    # Resume from saved weights when checkpoint paths are configured.
    training_iter = 0
    if cfg.TRAIN.NET_G != '':
        netG.load_state_dict(torch.load(cfg.TRAIN.NET_G))
        print('Loaded Generator from saved model.', cfg.TRAIN.NET_G)
        # Checkpoint names end in ``_<iter>.pth``; continue from iter + 1.
        underscore = cfg.TRAIN.NET_G.rfind('_') + 1
        dot = cfg.TRAIN.NET_G.rfind('.')
        training_iter = int(cfg.TRAIN.NET_G[underscore:dot]) + 1

    if cfg.TRAIN.NET_D != '':
        print('Loading Discriminator from %s.pth' % (cfg.TRAIN.NET_D))
        netD.load_state_dict(torch.load('%s.pth' % (cfg.TRAIN.NET_D)))

    inception_model = INCEPTION_V3()

    # Move everything onto the GPU when CUDA is enabled.
    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
        inception_model = inception_model.cuda()

    # Scoring model only — keep it in eval mode.
    inception_model.eval()

    return netG, netD, inception_model, training_iter
예제 #8
0
    def evaluate(self, split_dir):
        """Generate images for every caption embedding in the split and
        save them (single images or super-image grids) under a hard-coded
        results directory.

        For each embedding, ``n_samples`` noise vectors are decoded; each
        sample is scored with the Inception net against a fixed list of
        bird class indices, and the top-5 samples per embedding are kept
        in ``score_list`` for optional filtering at save time. The FID
        computation is present but commented out.

        Args:
            split_dir: dataset split name; 'test' is remapped to 'valid'.
        """
        inception_model = INCEPTION_V3()
        # fid_model = FID_INCEPTION()
        if cfg.CUDA:
            inception_model.cuda()
        #     fid_model.cuda()
        inception_model.eval()
        # fid_model.eval()

        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            # print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps the load CPU-safe regardless of where the
            # checkpoint was saved.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            # s_tmp = cfg.TRAIN.NET_G
            # istart = s_tmp.rfind('_') + 1
            # iend = s_tmp.rfind('.')
            # iteration = int(s_tmp[istart:iend])
            # s_tmp = s_tmp[:s_tmp.rfind('/')]
            # save_dir = '%s/iteration%d' % (s_tmp, iteration)
            # save_dir = 'C:\\Users\\alper\\PycharmProjects\\MSGAN\\StackGAN++-Mode-Seeking\\results'
            # NOTE(review): hard-coded, machine-specific output path.
            save_dir = "D:\\results"

            nz = cfg.GAN.Z_DIM
            n_samples = 50
            # noise = Variable(torch.FloatTensor(self.batch_size, nz))
            noise = Variable(torch.FloatTensor(n_samples, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(tqdm(self.data_loader)):
                # if step == 8:
                #     break
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                # Number of caption embeddings available per image.
                embedding_dim = t_embeddings.size(1)
                # batch_size = imgs[0].size(0)
                # noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                inception_score_list = []
                fid_list = []
                score_list = []
                predictions = []
                fids = []
                for i in range(embedding_dim):
                    inception_score_list.append([])
                    fid_list.append([])
                    score_list.append([])

                    emb_imgs = []
                    # Draw n_samples images for this caption embedding.
                    for j in range(n_samples):
                        noise_j = noise[j].unsqueeze(0)
                        t_embeddings_i = t_embeddings[:, i, :]
                        fake_imgs, _, _ = netG(noise_j, t_embeddings_i)
                        # filenames_number ='_sample_%2.2d'%(j)
                        # filenames_new = []
                        # filenames_new.append(filenames[-1]+filenames_number)
                        # filenames_new = tuple(filenames_new)

                        # for selecting reasonable images
                        pred = inception_model(fake_imgs[-1].detach())
                        pred = pred.data.cpu().numpy()
                        predictions.append(pred)
                        # Inception class indices treated as "bird" classes;
                        # the max probability over them is the sample score.
                        bird_indices = [
                            7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 21, 23,
                            81, 84, 85, 86, 88, 90, 91, 93, 94, 95, 96, 97, 99,
                            129, 130, 133, 134, 135, 138, 141, 142, 143, 144,
                            146, 517
                        ]
                        score = np.max(pred[0, bird_indices])
                        score_list[i].append((j, score))
                        emb_imgs.append(fake_imgs[2].data.cpu())
                        if cfg.TEST.B_EXAMPLE:
                            # fake_img_list.append(fake_imgs[0].data.cpu())
                            # fake_img_list.append(fake_imgs[1].data.cpu())
                            fake_img_list.append(fake_imgs[2].data.cpu())
                        else:
                            self.save_singleimages(fake_imgs[-1], filenames, j,
                                                   save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                    # Keep only the top-5 (index, score) pairs per embedding.
                    score_list[i] = sorted(score_list[i],
                                           key=lambda x: x[1],
                                           reverse=True)[:5]
                    # for FID score
                    # ffi = [i[0].numpy() for i in emb_imgs]
                    fake_filtered_images = [
                        fake_img_list[i][0].numpy()
                        for i in range(len(fake_img_list))
                    ]
                    img_dir = os.path.join(cfg.DATA_DIR, "CUB_200_2011",
                                           "images",
                                           filenames[0].split("/")[0])
                    img_files = [
                        os.path.join(img_dir, i) for i in os.listdir(img_dir)
                    ]

                    # act_real = get_activations(img_files, fid_model)
                    # mu_real, sigma_real = get_fid_stats(act_real)
                    # print("mu_real: {}, sigma_real: {}".format(mu_real, sigma_real))

                    np_imgs = np.array(fake_filtered_images)
                    # print(np_imgs.shape)

                    # # print(type(np_imgs[0]))
                    # act_fake = get_activations(np_imgs, fid_model, img=True)
                    # mu_fake, sigma_fake = get_fid_stats(act_fake)
                    # fid_score = frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)
                    # fids.append(fid_score)
                    # print("mu_fake: {}, sigma_fake: {}".format(mu_fake, sigma_fake))
                # print(inception_score_list)

                # # calculate inception score
                # predictions = np.concatenate(predictions, 0)
                # mean, std = compute_inception_score(predictions, 10)
                # mean_nlpp, std_nlpp = \
                #     negative_log_posterior_probability(predictions, 10)
                # inception_score_list.append((mean, std, mean_nlpp, std_nlpp))

                # # for FID score
                # fake_filtered_images = [fake_img_list[i*n_samples + k[0]][0].numpy() for i, j in enumerate(score_list) for k in j]
                # # fake_filtered_images = [fake_img_list[i][0].numpy() for i in range(len(fake_img_list))]
                # img_dir = os.path.join(cfg.DATA_DIR, "CUB_200_2011", "images", filenames[0].split("/")[0])
                # img_files = [os.path.join(img_dir, i) for i in os.listdir(img_dir)]
                #
                # act_real = get_activations(img_files, fid_model)
                # mu_real, sigma_real = get_fid_stats(act_real)
                # # print("mu_real: {}, sigma_real: {}".format(mu_real, sigma_real))
                #
                # np_imgs = np.array(fake_filtered_images)
                # # print(np_imgs.shape)
                #
                # # print(type(np_imgs[0]))
                # act_fake = get_activations(np_imgs, fid_model, img=True)
                # mu_fake, sigma_fake = get_fid_stats(act_fake)
                # # print("mu_fake: {}, sigma_fake: {}".format(mu_fake, sigma_fake))
                #
                # # fid_score = frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)
                # fid_score = np.mean(fids)
                # fid_list.append(fid_score)
                # stats = 'step: {}, FID: {}, inception_score: {}, nlpp: {}\n'.format(step, fid_score, (mean, std), (mean_nlpp, std_nlpp))
                # with open("results\\stats.txt", "a+") as f:
                #     f.write(stats)
                # print(stats)

                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    # Optionally keep only the top-scoring samples per
                    # embedding before assembling the super-image.
                    if cfg.TEST.FILTER:
                        images_to_save = [
                            fake_img_list[i * n_samples + k[0]]
                            for i, j in enumerate(score_list) for k in j
                        ]
                    else:
                        images_to_save = fake_img_list
                    self.save_superimages(images_to_save, filenames, save_dir,
                                          split_dir, 256)
예제 #9
0
def load_network(gpus):
    """Build the text MLP encoder, the three-scale generators and the nine
    theta-conditioned discriminators, resuming each from its configured
    checkpoint.

    Args:
        gpus: list of GPU ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netEn, netsG, netsD, num_Ds, inception_model, count)``;
        ``inception_model`` is ``None`` when ``cfg.INCEPTION`` is off.
    """
    netEn = MLP_ENCODER(cfg.TEXT.DIMENSION_THETA1, cfg.TEXT.DIMENSION_THETA2,
                        cfg.TEXT.DIMENSION_THETA3)
    netEn.apply(weights_init)
    netEn = torch.nn.DataParallel(netEn, device_ids=gpus)
    print(netEn)

    if cfg.TRAIN.NET_MLP_EN != '':
        state_dict = torch.load(cfg.TRAIN.NET_MLP_EN)
        netEn.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_MLP_EN)

    # One generator per output resolution.
    netsG = []
    netsG.append(G_NET_64())
    netsG.append(G_NET_128())
    netsG.append(G_NET_256())
    # ``range`` instead of the Python-2-only ``xrange`` for Python-3
    # compatibility (the other loaders in this file already use ``range``).
    for i in range(len(netsG)):
        netsG[i].apply(weights_init)
        netsG[i] = torch.nn.DataParallel(netsG[i], device_ids=gpus)

    # Three discriminator families (theta3, theta2, theta1), each with one
    # network per enabled branch of the resolution tree.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_THETA3_NET16(cfg.TEXT.DIMENSION_THETA3))
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_THETA3_NET32(cfg.TEXT.DIMENSION_THETA3))
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_THETA3_NET64(cfg.TEXT.DIMENSION_THETA3))

    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_THETA2_NET32(cfg.TEXT.DIMENSION_THETA2))
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_THETA2_NET64(cfg.TEXT.DIMENSION_THETA2))
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_THETA2_NET128(cfg.TEXT.DIMENSION_THETA2))

    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_THETA1_NET64(cfg.TEXT.DIMENSION_THETA1))
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_THETA1_NET128(cfg.TEXT.DIMENSION_THETA1))
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_THETA1_NET256(cfg.TEXT.DIMENSION_THETA1))

    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.COUNT != '':
        # Resume counter persisted as a numpy scalar.
        count = np.load('%s' % (cfg.TRAIN.COUNT))
    if cfg.TRAIN.NET_G != '':
        for i in range(len(netsG)):
            # Log the same path that is actually loaded below (the old
            # message claimed '%s_%d.pth' while '%s%d' was loaded).
            print('Load %s%d' % (cfg.TRAIN.NET_G, i))
            state_dict = torch.load('%s%d' % (cfg.TRAIN.NET_G, i))
            netsG[i].load_state_dict(state_dict)

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            # Log the same path that is actually loaded below.
            print('Load %s%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i))
            netsD[i].load_state_dict(state_dict)

    # Inception net is optional and used for scoring only.
    if cfg.INCEPTION:
        inception_model = INCEPTION_V3()
        if cfg.CUDA:
            inception_model = inception_model.cuda()
        inception_model.eval()
    else:
        inception_model = None

    if cfg.CUDA:
        netEn.cuda()
        for i in range(len(netsG)):
            netsG[i].cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()

    return netEn, netsG, netsD, len(netsD), inception_model, count
예제 #10
0
# Test-time image pipeline: resize slightly larger, then random-crop to
# 256x256 with horizontal flips.
imsize = 256
image_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
dataset = TextDataset(DATA_DIR, split='test', base_size=64,
                              transform=image_transform)

# Latent vector size and number of samples drawn per caption.
nz = 100
n_samples = 10

# Scoring networks (FID and Inception), moved to GPU and kept in eval mode.
fid_model = FID_INCEPTION()
fid_model.cuda()
fid_model.eval()

inception_model = INCEPTION_V3()
inception_model.cuda()
inception_model.eval()

# Generator restored from a hard-coded, machine-specific checkpoint path;
# map_location keeps the load device-agnostic.
G_NET_Path = 'C:\\Users\\alper\\PycharmProjects\\MSGAN\\StackGAN++-Mode-Seeking\\models\\ours_new.pth'
netG = G_NET()
netG.apply(weights_init)
torch.cuda.set_device(0)
netG = netG.cuda()
netG = torch.nn.DataParallel(netG, device_ids=[0])
state_dict = \
    torch.load(G_NET_Path,
               map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)

예제 #11
0
    def evaluate(self, split_dir):
        """Generate and save images for every caption embedding in the
        split, then return placeholder FID statistics.

        NOTE(review): the metric computation (IS/NLPP/FID) after the
        ``return`` marked "to save time" is unreachable dead code kept for
        reference.

        Args:
            split_dir: dataset split name; 'test' is remapped to 'valid'.

        Returns:
            A list of two ``{'mu': 0, 'sigma': 0}`` placeholder dicts.
        """
        sample_num_per_image = 1
        self.inception_model = INCEPTION_V3().eval()

        if cfg.CUDA:
            self.inception_model = self.inception_model.cuda()
            # self.inception_model = torch.cuda.parallel.DistributedDataParallel(self.inception_model, device_ids=self.gpus, output_device=self.gpus[0])
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)

            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps the load CPU-safe regardless of where the
            # checkpoint was saved.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images: derived from the
            # checkpoint name ``.../netG_<iteration>.pth``.
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise_raw = (torch.FloatTensor(self.batch_size, nz).requires_grad_())
            if cfg.CUDA:
                netG.cuda()
                noise_raw = noise_raw.cuda()
            else:
                pass

            # Accumulators for the (currently skipped) metric computation.
            predictions_g = []
            predictions_r = []
            activate_g = []
            activate_r = []
            IS_mean_add = 0
            FID_add = 0
            NLPP_mean_add = 0
            metric_cnt = 0
            count = 0
            # switch to evaluate mode
            netG.eval()
            loader_bar = tqdm.tqdm(self.data_loader)
            for step, data in enumerate(loader_bar, 0):
                count += 1
                imgs, t_embeddings, filenames = data
                # print(t_embeddings.shape)
                # Move embeddings and images to the right device without
                # tracking gradients; both may arrive as lists of tensors.
                if cfg.CUDA:
                    if isinstance(t_embeddings, list):
                        t_embeddings = [emb.requires_grad_(False).float().cuda() for emb in t_embeddings]
                    else:   
                        t_embeddings = t_embeddings.requires_grad_(False).float().cuda()
                    if isinstance(imgs, list):
                        imgs = [img.requires_grad_(False).float().cuda() for img in imgs]
                    else:
                        imgs = imgs.requires_grad_(False).float().cuda()
                else:
                    if isinstance(t_embeddings, list):
                        t_embeddings = [emb.requires_grad_(False).float() for emb in t_embeddings]
                    else:   
                        t_embeddings = t_embeddings.requires_grad_(False).float()
                    if isinstance(imgs, list):
                        imgs = [img.requires_grad_(False).float() for img in imgs]
                    else:
                        imgs = imgs.requires_grad_(False).float()
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                # Number of caption embeddings available per image.
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise = noise_raw[:batch_size]
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)
                # trunc
                # noise[noise<-cfg.TEST.TRUNC] = -cfg.TEST.TRUNC
                # noise[noise>cfg.TEST.TRUNC] = cfg.TEST.TRUNC

                fake_img_list = []

                for i in range(embedding_dim):
                    for sample_idx in range(sample_num_per_image):
                        # Fresh noise per sample; generation only, no grads.
                        noise.data.normal_(0, 1)
                        with torch.autograd.no_grad():
                            fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                            real_imgs = imgs
                        # pred_g, pool3_g = self.inception_model(fake_imgs[-1].detach())
                        # pred_r, pool3_r = self.inception_model(real_imgs[-1].detach())
                    # predictions_g.append(pred_g.data.cpu().numpy())
                    # predictions_r.append(pred_r.data.cpu().numpy())
                    # activate_g.append(pool3_g.data.cpu().numpy())
                    # activate_r.append(pool3_r.data.cpu().numpy())

                        if cfg.TEST.B_EXAMPLE:
                            # fake_img_list.append(fake_imgs[0].data.cpu())
                            # fake_img_list.append(fake_imgs[1].data.cpu())
                            fake_img_list.append(fake_imgs[2].data.cpu())
                        else:
                            self.save_singleimages(fake_imgs[-1], filenames,
                                                   save_dir, split_dir, i, 256, sample_idx)
                            # self.save_singleimages(fake_imgs[-2], filenames,
                            #                        save_dir, split_dir, i, 128)
                            # self.save_singleimages(fake_imgs[-3], filenames,
                            #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
            print("len pred: {}".format(len(predictions_g)))
            return [{'mu':0, 'sigma':0},{'mu':0, 'sigma':0}]  # to save time

            # NOTE(review): everything below is unreachable because of the
            # early return above; kept as reference for the full metrics.
            predictions_g = np.concatenate(predictions_g, 0)
            predictions_r = np.concatenate(predictions_r, 0)
            activate_g = np.concatenate(activate_g, 0)
            activate_r = np.concatenate(activate_r, 0)
            mean, std = compute_inception_score(predictions_g, 10)
            # print('mean:', mean, 'std', std)
            self.summary_writer.add_scalar('Inception_mean', mean, count)
            #
            mean_nlpp, std_nlpp = \
                negative_log_posterior_probability(predictions_g, 10)
            self.summary_writer.add_scalar('NLPP_mean', mean_nlpp, count)
            # FID
            fid, fid_data = compute_frethet_distance(activate_g, activate_r)
            self.summary_writer.add_scalar("FID", fid, count)
            IS_mean_add += mean
            FID_add += fid
            NLPP_mean_add += mean_nlpp
            metric_cnt += 1

            IS_mean, FID_mean, NLPP_mean = IS_mean_add/metric_cnt, FID_add/metric_cnt, NLPP_mean_add/metric_cnt
            print("total, IS mean:{}, FID mean:{}, NLPP mean:{}".format(IS_mean, FID_mean, NLPP_mean))

            return fid_data
예제 #12
0
def load_network(gpus:list, distributed:bool):
    """Build the generator, per-branch discriminators and Inception model,
    wrapping each with either ``DistributedDataParallel`` (when
    ``distributed`` is set) or ``DataParallel``.

    Args:
        gpus: GPU ids for the parallel wrapper.
        distributed: use DistributedDataParallel instead of DataParallel.

    Returns:
        Tuple ``(netG, netsD, num_Ds, inception_model, count)`` where
        ``count`` is the iteration to resume from (0 for a fresh run).
    """
    netG = G_NET()
    netG.apply(weights_init)
    if distributed:
        netG = netG.cuda()
        netG = torch.nn.parallel.DistributedDataParallel(netG, device_ids=gpus, output_device=gpus[0], broadcast_buffers=True)
    else:
        if cfg.CUDA:
            netG = netG.cuda()
        netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # One discriminator per enabled branch of the image-resolution tree.
    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:
    # netsD_module = nn.ModuleList(netsD)
    # netsD_module.apply(weights_init)
    # netsD_module = torch.nn.parallel.DistributedDataParallel(netsD_module.cuda(), device_ids=gpus, output_device=gpus[0])
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        if distributed:
            netsD[i] = torch.nn.parallel.DistributedDataParallel(netsD[i].cuda(), device_ids=gpus, output_device=gpus[0], broadcast_buffers=True
                # , process_group=pg_Ds[i]
                )
        else:
            netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
        print(netsD[i])
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        # map_location keeps the load device-agnostic.
        state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # Checkpoints are named ``<prefix>_<iter>.pth``; resume at iter + 1.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = cfg.TRAIN.NET_G[istart:iend]
        count = int(count) + 1

    if cfg.TRAIN.NET_D != '':
        for i in range(len(netsD)):
            print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i))
            state_dict = torch.load('%s%d.pth' % (cfg.TRAIN.NET_D, i), map_location=lambda storage, loc: storage)
            netsD[i].load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    # Wrap the Inception scoring model to match the chosen parallel mode.
    if not distributed:
        if cfg.CUDA:
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
            inception_model = inception_model.cuda()
        inception_model = torch.nn.DataParallel(inception_model, device_ids=gpus)
    else:
        inception_model = torch.nn.parallel.DistributedDataParallel(inception_model.cuda(), device_ids=gpus, output_device=gpus[0])
        pass
    # inception_model = inception_model.cpu() #to(torch.device("cuda:{}".format(gpus[0])))
    inception_model.eval()
    print("model device, G:{}, D:{}, incep:{}".format(netG.device_ids, netsD[0].device_ids, inception_model.device_ids))
    return netG, netsD, len(netsD), inception_model, count