def test(config):
    print(config)

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    result_dir = join(config.result_path, 'test-%d' % config.test_epoch)
    if not exists(result_dir):
        os.makedirs(result_dir)

    test_dataset = TestDataset(test_num=2 * config.rows * config.cols)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.rows * config.cols,
                             num_workers=1)
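    # One batch per saved grid: rows * cols images, two grids in total.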

    generator = Generator(resolution=config.resolution,
                          output_act=config.output_act,
                          norm=config.norm,
                          device=device).to(device)

    if config.load_G:
        if torch.cuda.is_available():
            generator.load_model(config.load_G)
        else:
            generator.load_model(config.load_G, map_location='cpu')

    generator.eval()

    for i, data in enumerate(tqdm(test_loader)):
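        # Each batch carries pre-sampled latent noise; alpha=1.0 keeps the
        # highest-resolution block of the progressive generator fully active.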
        noises = data['noise'].float().to(device)
        fake_images = generator(noises, alpha=1.0)

        # Convert NCHW tensors to HWC numpy arrays for saving; tanh outputs
        # are rescaled from [-1, 1] to [0, 1].
        fake_images = fake_images.detach().cpu().numpy(
        )[0:config.rows * config.cols].transpose((0, 2, 3, 1))
        if config.output_act == 'tanh':
            fake_images = (fake_images + 1.) * 0.5

        save_result(rows=config.rows,
                    cols=config.cols,
                    images=fake_images,
                    result_file=join(result_dir, "fake-%d.png" % i))
def predict_folder(model, args):
    # NOTE: assumed wrapper -- the original snippet begins mid-function; it
    # expects a loaded detection model and parsed args providing
    # input_folder, output_folder, polygon and show.
    img_folder = pathlib.Path(args.input_folder)
    for img_path in tqdm(get_file_list(args.input_folder, p_postfix=['.jpg'])):
        img = cv2.imread(img_path)
        # if img.shape[0] /img.shape[1] > 2 or img.shape[1]/img.shape[0] > 2:
        #     continue
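        # predict() returns the probability map, detected boxes, per-box
        # scores, and the inference time for one image.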
        preds, boxes_list, score_list, t = model.predict(
            img_path, is_output_polygon=args.polygon, runtime='trt')
        print('time cost: {}s'.format(t))
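        # OpenCV loads images as BGR; [:, :, ::-1] flips to RGB before
        # cropping text regions and drawing boxes.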
        crops = crop_bbox(img[:, :, ::-1], boxes_list)
        img = draw_bbox(img[:, :, ::-1], boxes_list)
        if args.show:
            show_img(preds)
            show_img(img, title=os.path.basename(img_path))
            plt.show()
        # Save results to the output folder
        os.makedirs(args.output_folder, exist_ok=True)
        img_path = pathlib.Path(img_path)
        output_path = os.path.join(args.output_folder,
                                   img_path.stem + '_result.jpg')
        pred_path = os.path.join(args.output_folder,
                                 img_path.stem + '_pred.jpg')
        cv2.imwrite(output_path, img[:, :, ::-1])
        cv2.imwrite(pred_path, preds * 255)
        for i, crop in enumerate(crops):
            cv2.imwrite(
                os.path.join(args.output_folder,
                             img_path.stem + '_text_{:02d}.jpg'.format(i)),
                crop)
        save_result(output_path.replace('_result.jpg', '.txt'), boxes_list,
                    score_list, args.polygon)
def crossModalQueries(embeddings=None, topk=5, mode1="au", mode2="im", use_tags=False, result_path=None, plot=False):
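    """Retrieve the top-k nearest neighbours across modalities by squared L2
    distance between embeddings; mode1/mode2 are "im" (image) or "au" (audio).
    """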

    if plot and topk != 5:
        raise ValueError("When plot is True, topk must be 5.")

    finalTag = getNumToTagsMap()
    # print(finalTag)

    # Collect the sorted list of test audio file names (used when saving
    # query results below).
    for _, _, files in os.walk("./data/test/audio"):
        audioFiles = sorted(files)

    t = torch.load(embeddings)

    # Entries 2 and 3 (the embedding lists) were saved per batch; flatten
    # them into single arrays.
    for i in [2, 3]:
        t[i] = np.concatenate(t[i])
    # Unpack according to how many lists were saved.
    if len(t) == 6:
        imgList, audList, imgEmbedList, audEmbedList, vidTagList, audTagList = t
    elif len(t) == 7:
        imgList, audList, imgEmbedList, audEmbedList, vidTagList, audTagList, audioSampleList = t
    elif len(t) == 4:
        imgList, audList, imgEmbedList, audEmbedList = t
    else:
        raise ValueError("Invalid number of items: Found {} in 'savedEmbeddings.pt'".format(len(t)))

    print("Loaded embeddings.")

    # imgList = bgr2rgb(imgList)

    print("Size of data : " + str(len(imgEmbedList)))

    # Open a file and store your queries here
    if plot:
        res = open("results/results_{0}_{1}.txt".format(mode1, mode2), "w+")

    assert mode1 != mode2

    res_queries = []
    res_tags = []
    for i in range(len(imgEmbedList)):
        res_query = []  # stays empty when use_tags is False; avoids NameError
        if mode1 == "im":
            embed = imgEmbedList[i]
        else:
            embed = audEmbedList[i]

        # Compute distance
        if mode2 == "im":
            dist = ((embed - imgEmbedList) ** 2).sum(1)
        else:
            dist = ((embed - audEmbedList) ** 2).sum(1)

        # Indices of the top-k nearest embeddings
        idx = dist.argsort()[:topk]
        if plot:
            plt.clf()
        num_fig = idx.shape[0]

        # Actual query
        if use_tags:
            if plot:
                ax = plt.subplot(2, 3, 1)
                ax.set_title("Query: " + str([finalTag[x] for x in vidTagList[i]]))
            res_query = [finalTag[x] for x in vidTagList[i]]

        if plot:
            plt.axis("off")
            plt.imshow(imgList[i].squeeze().transpose(1, 2, 0))

        # Top k matches
        res_tag = []
        for j in range(num_fig):
            if use_tags:
                res_tag_ = [finalTag[x] for x in vidTagList[idx[j]]]
                if plot:
                    ax = plt.subplot(2, 3, j + 2)
                    ax.set_title(str(res_tag_))
                res_tag.append(res_tag_)
            if plot:
                plt.imshow(imgList[idx[j]].squeeze().transpose(1, 2, 0))
                plt.axis("off")

        # plt.tight_layout()

        if plot:
            plt.draw()
            plt.pause(0.001)
            # Wait for a key press before prompting to save.
            input()
            ans = input("Do you want to save? (quit: q): ")
            if ans == "q":
                break
            elif ans == "y":
                if mode1 == "au":
                    res.write(audioFiles[audioSampleList[i][0]] + "\n")
                    print(audioFiles[audioSampleList[i][0]])
                else:
                    tmpFiles = map(lambda x: audioFiles[x], idx)
                    line = ", ".join(tmpFiles)
                    print(line)
                    res.write(line + "\n")
                plt.savefig("results/embed_{0}_{1}_{2}.png".format(mode1, mode2, i))

        res_queries.append(res_query)
        res_tags.append(res_tag)
    save_result(result_path, res_queries, res_tags)
    if plot:
        res.close()
def init_args():
    import argparse
    parser = argparse.ArgumentParser(description='DBNet.pytorch')
    parser.add_argument('--model_path', default='./output/DBNet_resnet18_FPN_DBHead/checkpoint/model_best.pth', type=str)
    parser.add_argument('--img_path', default='./input/img_1.jpg', type=str, help='img path for predict')
    parser.add_argument('--polygon', action='store_true', help='output polygon or box')
    parser.add_argument('--show', action='store_true', help='show result')
    parser.add_argument('--save_result', action='store_true', help='save box and score to txt file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from utils.util import show_img, draw_bbox, save_result

    args = init_args()
    print(args)
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # Initialize the detection model
    model = Pytorch_model(args.model_path, gpu_id=0)
    preds, boxes_list, score_list, t = model.predict(args.img_path, is_output_polygon=args.polygon)
    if args.show:
        show_img(preds)
        img = draw_bbox(cv2.imread(args.img_path)[:, :, ::-1], boxes_list)
        show_img(img)
        plt.show()
    if args.save_result:
        # Save boxes and scores to a txt file
        save_result(args.img_path, boxes_list, score_list, args.polygon)
def AudioToAudioQueries(embeddings=None,
                        topk=5,
                        use_tags=False,
                        result_path=None,
                        plot=False):
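    """Query each audio embedding against all audio embeddings and keep the
    top-k nearest matches; the closest match is the query itself."""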

    if plot and topk != 5:
        raise ValueError("When plot is True, topk must be 5.")

    finalTag = getNumToTagsMap()
    # print(finalTag)

    t = torch.load(embeddings)

    for i in [2, 3]:
        t[i] = np.concatenate(t[i])

    # Generalize here
    if len(t) == 6:
        imgList, audList, imgEmbedList, audEmbedList, vidTagList, audTagList = t
    elif len(t) == 7:
        imgList, audList, imgEmbedList, audEmbedList, vidTagList, audTagList, audioSampleList = t
    elif len(t) == 4:
        imgList, audList, imgEmbedList, audEmbedList = t
    else:
        raise ValueError(
            "Invalid number of items: Found {} in 'savedEmbeddings.pt'".format(
                len(t)))

    print("Loaded embeddings.")

    print("Size of data : " + str(len(audEmbedList)))

    res_queries = []
    res_tags = []
    for i in range(len(audEmbedList)):
        res_query = []  # stays empty when use_tags is False; avoids NameError
        embed = audEmbedList[i]
        dist = ((embed - audEmbedList)**2).sum(1)
        # Indices of the top-k nearest audio embeddings
        idx = dist.argsort()[:topk]

        num_fig = idx.shape[0]
        if plot:
            plt.clf()
            ax = plt.subplot(1, 3, 1)

        if use_tags:
            res_query = [finalTag[x] for x in audTagList[i]]
            if plot:
                ax.set_title(str(res_query))
        if plot:
            plt.axis("off")
            plt.imshow(audList[idx[0]].transpose(1, 2, 0))

        res_tag = []
        for j in range(num_fig):
            if use_tags:
                res_tag_ = [finalTag[x] for x in audTagList[idx[j]]]
                res_tag.append(res_tag_)
            # The query sits in subplot 1; matches 1..num_fig-1 fill the
            # remaining cells (2, 3, 5, 6) of the 2x3 grid.
            if plot and j >= 1:
                ax = plt.subplot(2, 3, j + 1 + int(j / 3))
                if use_tags:
                    ax.set_title(str(res_tag_))
                plt.imshow(audList[idx[j]].transpose(1, 2, 0))
                plt.axis("off")

        if plot:
            plt.draw()
            plt.pause(0.001)
            # Wait for a key press, then ask whether to save the figure.
            input()
            ans = input("Do you want to save? ")
            if ans == "y":
                plt.savefig("results/embed_au_au_{0}.png".format(i))

        res_queries.append(res_query)
        res_tags.append(res_tag)
    save_result(result_path, res_queries, res_tags)
def train(config):
    print(config)

    if config.device_id:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.device_id

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    if not exists(config.ckpt_path):
        os.makedirs(config.ckpt_path)

    if not exists(config.result_path):
        os.makedirs(config.result_path)

    with open(join(config.result_path, 'config.txt'), 'w') as f:
        f.write(str(config))

    writer = SummaryWriter(config.result_path)

    # tanh outputs lie in [-1, 1], so real images are normalized to match;
    # linear outputs are compared against images in [0, 1].
    if config.output_act == 'linear':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif config.output_act == 'tanh':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    else:
        raise ValueError('Unsupported output_act: %s' % config.output_act)

    train_dataset = TrainDataset(config.celeba_hq_dir,
                                 config.train_file,
                                 resolution=config.resolution,
                                 transform=train_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.num_workers)

    generator = Generator(resolution=config.resolution,
                          output_act=config.output_act,
                          norm=config.norm,
                          device=device).to(device)
    discriminator = Discriminator(resolution=config.resolution,
                                  device=device).to(device)

    ganLoss = GANLoss(gan_mode=config.gan_type,
                      target_real_label=0.9,
                      target_fake_label=0.).to(device)
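    # target_real_label=0.9 applies one-sided label smoothing to the
    # discriminator's real targets.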

    if config.gan_type == 'wgangp':
        gpLoss = GradientPenaltyLoss(device).to(device)

    if config.load_G:
        if torch.cuda.is_available():
            generator.load_model(config.load_G)
        else:
            generator.load_model(config.load_G, map_location='cpu')
        print('Loading %s' % config.load_G)

    if config.load_D:
        if torch.cuda.is_available():
            discriminator.load_model(config.load_D)
        else:
            discriminator.load_model(config.load_D, map_location='cpu')
        print('Loading %s' % config.load_D)

    g_lr = config.g_lr
    d_lr = config.d_lr

    optimG = torch.optim.Adam(params=generator.parameters(), lr=g_lr)
    optimD = torch.optim.Adam(params=discriminator.parameters(), lr=d_lr)

    # Resume fade-in / learning-rate state from the last training run (the
    # loaded learning rates take effect at the next decay below).
    if config.start_idx > 0:
        with open(
                join(config.result_path, 'config-%d.json' % config.start_idx),
                'r') as f:
            temp_data = json.load(f)
            alpha = temp_data['alpha']
            start_train_point = temp_data['start_train_point']
            g_lr = temp_data['g_lr']
            d_lr = temp_data['d_lr']

    if config.phase == 'fadein':
        if config.start_idx == 0:
            # alpha ramps over [0, 0.9] during fade-in
            alpha = 0.0
            start_train_point = 1
        epoch_length = config.epochs // 10
    elif config.phase == 'stabilize':
        alpha = 1.0
        start_train_point = 0

    for epoch in range(1 + config.start_idx, config.epochs + 1):

        for i, data in enumerate(tqdm(train_loader)):
            real_images = data['image'].to(device)
            noises = data['noise'].float().to(device)

            optimG.zero_grad()
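            # Generator step: push D(G(z)) toward the "real" label.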
            fake_images = generator(noises, alpha)
            fake_labels = discriminator(fake_images, alpha)
            gan_loss = ganLoss(fake_labels, True)
            gan_loss.backward()
            optimG.step()

            optimD.zero_grad()
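            # Discriminator step: real images toward "real", detached fakes
            # toward "fake".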
            dis_loss = ganLoss(discriminator(real_images, alpha), True) + \
                       ganLoss(discriminator(fake_images.detach(), alpha), False)
            if config.gan_type == 'wgangp':
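                # WGAN-GP: add the gradient penalty, weighted by l_gp.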
                gp_loss = gpLoss(discriminator, real_images,
                                 fake_images.detach())
                dis_loss = dis_loss + config.l_gp * gp_loss
            dis_loss.backward()
            optimD.step()

            # Fade-in schedule: once every epoch_length epochs, raise alpha
            # by 0.1 (capped by the schedule at 0.9).
            if config.phase == 'fadein':
                if epoch > epoch_length * start_train_point:
                    start_train_point += 1
                    alpha += 0.1

            if i % 500 == 0:
                if config.gan_type == 'wgangp':
                    print(
                        'Epoch: %d/%d | Step: %d/%d | G loss: %.4f | D loss: %.4f | gp loss: %.4f'
                        % (epoch, config.epochs, i, len(train_loader),
                           gan_loss.item(), dis_loss.item(), gp_loss.item()))
                else:
                    print(
                        'Epoch: %d/%d | Step: %d/%d | G loss: %.4f | D loss: %.4f'
                        % (epoch, config.epochs, i, len(train_loader),
                           gan_loss.item(), dis_loss.item()))

                # Convert NCHW tensors to HWC numpy arrays for saving; tanh
                # outputs are rescaled from [-1, 1] to [0, 1].
                fake_images = fake_images.detach().cpu().numpy(
                )[0:6].transpose((0, 2, 3, 1))
                real_images = real_images.detach().cpu().numpy(
                )[0:6].transpose((0, 2, 3, 1))
                if config.output_act == 'tanh':
                    fake_images = (fake_images + 1.) * 0.5
                    real_images = (real_images + 1.) * 0.5
                save_result(rows=2,
                            cols=3,
                            images=fake_images,
                            result_file=join(
                                config.result_path,
                                "fake-epoch-%d-step-%d-alpha-%.1f.png" %
                                (epoch, i, alpha)))
                save_result(rows=2,
                            cols=3,
                            images=real_images,
                            result_file=join(
                                config.result_path,
                                "real-epoch-%d-step-%d.png" % (epoch, i)))

                writer.add_scalars('loss', {
                    'G loss': gan_loss.item(),
                    'D loss': dis_loss.item()
                }, (epoch - 1) * len(train_loader) + i)

                if config.gan_type == 'wgangp':
                    writer.add_scalars('loss', {'gp': gp_loss.item()},
                                       (epoch - 1) * len(train_loader) + i)

        with open(join(config.result_path, 'config-%d.json' % epoch),
                  'w') as f:
            temp_data = {
                'alpha': alpha,
                'start_train_point': start_train_point,
                'g_lr': g_lr,
                'd_lr': d_lr,
            }
            json.dump(temp_data, f)

        if epoch % 4 == 0:
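            # Checkpoint both networks every 4 epochs.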
            generator.save_model(
                join(config.ckpt_path, 'G-epoch-%d.pkl' % epoch))
            discriminator.save_model(
                join(config.ckpt_path, 'D-epoch-%d.pkl' % epoch))

        if epoch % 20 == 0:
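            # Decay both learning rates 10x and rebuild the Adam optimizers
            # with the new rates.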
            g_lr *= 0.1
            d_lr *= 0.1
            optimG = torch.optim.Adam(params=generator.parameters(), lr=g_lr)
            optimD = torch.optim.Adam(params=discriminator.parameters(),
                                      lr=d_lr)

    writer.export_scalars_to_json(join(config.result_path, 'scalars.json'))
    writer.close()