예제 #1
0
    # Timestamped run directory keeps successive runs from clobbering each other.
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # cfg.GPU_ID is a comma-separated device list (e.g. "0,1").
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        # Augment and normalize to [-1, 1] (presumably matching a tanh
        # generator output range -- confirm against GANTrainer).
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        # Batch is scaled across all visible GPUs; drop_last keeps batch
        # shapes constant for the GAN losses.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Sampling mode: generate images from pre-computed caption
        # embeddings stored in a Torch7 .t7 file.
        datapath = '%s/test/val_captions_custom.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
예제 #2
0
        # Snapshot the source files and config alongside the run outputs
        # so the experiment is reproducible from the output directory alone.
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        imsize=64
        # Normalize to [-1, 1]; no crop/flip augmentation in this variant.
        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform)
        assert dataset
        # drop_last keeps batch shapes constant across iterations.
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        # Evaluation mode: sample from the "test" split instead of training.
        imsize=64
        datapath= '%s/test/' % (cfg.DATA_DIR)
        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="test", imsize=imsize, transform=img_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))
        algo = GANTrainer(output_dir)
        algo.sample(dataloader, num_samples=25, draw_bbox=True)
예제 #3
0
            # Directory already exists: fine. Any other OS error is re-raised.
            if exc.errno == errno.EEXIST and os.path.isdir(path):
                pass
            else:
                raise

        # Snapshot the entry script, trainer, model, dataset code and the
        # config into the run directory for reproducibility.
        copyfile(sys.argv[0], output_dir + "/" + sys.argv[0])
        copyfile("trainer.py", output_dir + "/" + "trainer.py")
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        imsize=64

        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # crop=True suggests the dataset performs its own cropping
        # (no RandomCrop in the transform) -- confirm in TextDataset.
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        # Sampling mode reads directly from the test directory.
        datapath = os.path.join(cfg.DATA_DIR, "test")
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, draw_bbox=True)
예제 #4
0
        # Stage 2 trains at 256x256; images are first resized to 268 --
        # presumably to leave slack for a dataset-side crop (crop=True below).
        elif cfg.STAGE == 2:
            resize = 268
            imsize = 256

        img_transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              cfg.IMG_DIR,
                              split="train",
                              imsize=imsize,
                              transform=img_transform,
                              crop=True,
                              stage=cfg.STAGE)
        assert dataset
        # drop_last keeps batch shapes constant for the GAN losses.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Sampling mode: generate images from the test directory.
        datapath = '%s/test/' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, stage=cfg.STAGE, draw_bbox=True)
예제 #5
0
    # Seed Python, CPU-torch and (when enabled) all CUDA RNGs so that a run
    # is reproducible from the recorded manualSeed.
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)
    # Timestamped run directory: <dataset>_<config>_<YYYY_MM_DD_HH_MM_SS>.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # cfg.GPU_ID is a comma-separated device list; count scales the batch.
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        # Augment and normalize images to [-1, 1].
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, 'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Sampling mode: generate from pre-computed caption embeddings
        # stored in a Torch7 .t7 file.
        datapath= '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
예제 #6
0
def main(gpu_id, data_dir, manual_seed, cuda, train_flag, image_size,
         batch_size, workers, stage, dataset_name, config_name, max_epoch,
         snapshot_interval, net_g, net_d, z_dim, generator_lr,
         discriminator_lr, lr_decay_epoch, coef_kl, stage1_g, embedding_type,
         condition_dim, df_dim, gf_dim, res_num, text_dim, regularizer):
    """Entry point: seed RNGs, build the dataloader/trainer and dispatch.

    Dispatches to one of three modes:
      * train_flag truthy          -> train for `stage`,
      * dataset_name == 'birds'
        and train_flag is False    -> birds evaluation,
      * otherwise                  -> sample from saved caption embeddings.

    Refactor note: the GANTrainer construction was duplicated verbatim in
    all three branches, and the transform/dataset/dataloader construction
    in two of them; both are now built by local helpers. Behavior unchanged.
    """
    # Seed Python, CPU-torch and (when enabled) all CUDA RNGs; a random
    # seed is drawn when the caller did not supply one.
    if manual_seed is None:
        manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    if cuda:
        torch.cuda.manual_seed_all(manual_seed)
    # Timestamped run directory: <dataset>_<config>_<YYYY_MM_DD_HH_MM_SS>.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % (dataset_name, config_name, timestamp)

    # gpu_id is a comma-separated device list; count scales the batch size.
    num_gpu = len(gpu_id.split(','))

    def _build_trainer():
        # One place for the (long) GANTrainer argument list, identical in
        # every branch of the original.
        return GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)

    def _build_train_dataloader():
        # Transform + dataset + loader shared by the train and birds-eval
        # paths. Images are augmented and normalized to [-1, 1]; drop_last
        # keeps batch shapes constant.
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(data_dir,
                              'train',
                              imsize=image_size,
                              transform=image_transform)
        assert dataset
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size * num_gpu,
                                           drop_last=True,
                                           shuffle=True,
                                           num_workers=int(workers))

    if train_flag:
        dataloader = _build_train_dataloader()
        algo = _build_trainer()
        algo.train(dataloader, stage, text_dim, gf_dim, condition_dim, z_dim,
                   df_dim, res_num)

    elif dataset_name == 'birds' and train_flag is False:
        # NOTE(review): `train_flag is False` kept verbatim -- a falsy
        # non-False value (e.g. 0 or None) still falls through to sampling.
        dataloader = _build_train_dataloader()
        algo = _build_trainer()
        algo.birds_eval(dataloader, stage)

    else:
        # Sampling mode: generate from pre-computed caption embeddings.
        datapath = '%s/test/val_captions.t7' % (data_dir)
        algo = _build_trainer()
        algo.sample(datapath, stage)
예제 #7
0
    # Initialize the main class which includes the training and evaluation
    algo = GANTrainer(output_dir, cfg_path=args.cfg_file)

    if cfg.TRAIN.FLAG:
        algo.train(dataloader)

    elif args.eval:
        date_str = datetime.datetime.now().strftime('%b-%d-%I%M%p-%G')

        if args.FID_eval:
            '''Do FID evaluations.'''
            f = open(os.path.join(output_dir, 'all_FID_eval.txt'), 'a')
            for net_G_name in net_G_names:
                cfg.NET_G = net_G_name
                algo.sample(dataloader, eval_name='eval',
                        eval_num=args.eval_num)
                fid_score_now = \
                    fid_scores(output_dir, cfg, sample_num=args.sample_num, 
                    gen_images_path=args.gen_paths, loop=True)

                f.write('%s, %s, %.4f\n' % (date_str, net_G_name, fid_score_now))

            f.close()

            # Save the best FID score model
            with open(os.path.join(output_dir, 'all_FID_eval.txt'), 'r') as f:
                all_lines = f.readlines()
                score_array =\
                        np.asarry([float(line.strip('\n').split(', ')[-1]) \
                        for line in all_lines])
                if fid_score_now == score_array.min():
예제 #8
0
        # Snapshot source files and config alongside the run outputs for
        # reproducibility.
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        # NOTE(review): if cfg.STAGE is neither 1 nor 2, `resize`/`imsize`
        # are never bound and the transform below raises NameError.
        if cfg.STAGE == 1:
            resize = 76
            imsize=64
        elif cfg.STAGE == 2:
            resize = 268
            imsize = 256

        # Resize leaves slack above imsize, presumably for the dataset-side
        # crop (crop=True below) -- confirm in TextDataset.
        img_transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, cfg.IMG_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True, stage=cfg.STAGE)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        #https://github.com/taoxugit/AttnGAN/blob/0d000e652b407e976cb88fab299e8566f3de8a37/code/main.py#L146
        datapath= '%s/test/' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=5, stage=cfg.STAGE, draw_bbox=False)