コード例 #1
0
    def __init__(self, networks, txt_emb):
        """Set up an evaluator holding the networks, text embeddings and test data.

        Per the original author's note, most of this was pulled from notebooks
        and is not yet functional.

        Args:
            networks: Sequence of exactly six items, unpacked in the order
                image encoder, image generator, text encoder, text generator,
                image discriminator, latent discriminator. Only the first
                four are stored on the instance.
            txt_emb: Stored on ``self.txt_emb`` first; that attribute is then
                replaced further down by the embedding table loaded from disk.
        """
        # The two discriminators are unpacked but not kept.
        img_enc, img_gen, txt_enc, txt_gen, _disc_image, _disc_latent = networks

        self.txt_emb = txt_emb
        self.text_encoder = txt_enc
        self.text_generator = txt_gen
        self.image_encoder = img_enc
        self.image_generator = img_gen

        # Load the fasttext vectors (e.g., birds.en.vec) into a frozen
        # 300-dimensional embedding table.
        vocab, pretrained = load_external_embeddings(
            os.path.join(cfg.DATA_DIR, cfg.DATASET_NAME + ".en.vec"))
        frozen_emb = nn.Embedding(len(vocab), 300, sparse=False)
        frozen_emb.weight.data.copy_(pretrained)
        frozen_emb.weight.requires_grad = False
        self.txt_dico = vocab
        # This replaces the ``txt_emb`` argument stored above.
        self.txt_emb = frozen_emb

        # Random crop + horizontal flip, then normalize each of the three
        # channels with mean 0.5 and std 0.5.
        pipeline = [
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        test_set = TextDataset(cfg.DATA_DIR,
                               'test',
                               imsize=cfg.IMSIZE,
                               transform=transforms.Compose(pipeline))
        assert test_set

        # Shuffled batches of 10 over the 'test' split; a short final batch
        # is dropped.
        self.dataloader = torch.utils.data.DataLoader(
            test_set,
            batch_size=10,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        # cfg.GPU_ID is a comma-separated string of device indices.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
コード例 #2
0
        torch.cuda.manual_seed_all(args.manualSeed)
    # Build a per-run output directory name from the dataset name, config
    # name and a local-timezone timestamp.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # One GPU per comma-separated entry in cfg.GPU_ID.
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        # Training branch: random crop + horizontal flip, then normalize each
        # of the three channels with mean 0.5 and std 0.5.
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        # Truthiness check only, and removed when Python runs with -O.
        # NOTE(review): presumably meant to catch an empty dataset - confirm
        # against TextDataset.
        assert dataset
        # The configured batch size is multiplied by the GPU count; a short
        # final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Not training: sample from the custom validation captions file.
        datapath = '%s/test/val_captions_custom.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
コード例 #3
0
            if exc.errno == errno.EEXIST and os.path.isdir(path):
                pass
            else:
                raise

        # Snapshot the launching script, the core modules and the config file
        # into output_dir alongside the run's results.
        # NOTE(review): the first destination appends sys.argv[0] verbatim, so
        # this presumably expects the script to be launched by bare file name
        # from its own directory - confirm.
        copyfile(sys.argv[0], output_dir + "/" + sys.argv[0])
        copyfile("trainer.py", output_dir + "/" + "trainer.py")
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        # Image size is fixed here rather than read from cfg.
        imsize=64

        # The transform only converts to a tensor and normalizes each of the
        # three channels with mean 0.5 and std 0.5; no crop or flip here.
        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # crop=True is forwarded to TextDataset; its effect is defined there.
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # Shuffled training batches; a short final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        datapath = os.path.join(cfg.DATA_DIR, "test")
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, draw_bbox=True)
コード例 #4
0
    # Print the configured GPU id string and the GPU count derived from it
    # (one GPU per comma-separated entry).
    print(cfg.GPU_ID)
    num_gpu = len(cfg.GPU_ID.split(','))
    #num_gpu = 8
    print(num_gpu)

    if cfg.TRAIN.FLAG:
        # Training branch: resize, random crop and horizontal flip, then
        # normalize each of the three channels with mean 0.5 and std 0.5.
        image_transform = transforms.Compose([
            transforms.Resize(cfg.IMSIZE),
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The number of classes is taken from the GAN condition dimension and
        # the labels from 'genre_train.csv'; how TextDataset uses them is
        # defined elsewhere.
        dataset = TextDataset(cfg.DATA_DIR,
                              num_classes=cfg.GAN.CONDITION_DIM,
                              csv_file='genre_train.csv',
                              split='train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # The configured batch size is multiplied by the GPU count; a short
        # final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Not training: point at the validation captions file and build the
        # trainer.
        datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
コード例 #5
0
ファイル: main.py プロジェクト: Kinetize/InstaGAN
    # Build a per-run output directory name from the dataset name, config
    # name and a local-timezone timestamp.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # One GPU per comma-separated entry in cfg.GPU_ID.
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        # Training branch: random crop + horizontal flip, then normalize each
        # of the three channels with mean 0.5 and std 0.5.
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The embedding type is read from the config and forwarded to
        # TextDataset.
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform,
                              embedding_type=cfg.EMBEDDING_TYPE)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # The configured batch size is multiplied by the GPU count; a short
        # final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Not training: point at the validation captions file and build the
        # trainer.
        datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
コード例 #6
0
                pass
            else:
                raise

        # Snapshot the launching script, the core modules and the config file
        # into output_dir alongside the run's results.
        # NOTE(review): the first destination appends sys.argv[0] verbatim, so
        # this presumably expects the script to be launched by bare file name
        # from its own directory - confirm.
        copyfile(sys.argv[0], output_dir + "/" + sys.argv[0])
        copyfile("trainer.py", output_dir + "/" + "trainer.py")
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        # Image size is fixed here rather than read from cfg.
        imsize=64
        # The transform only converts to a tensor and normalizes each of the
        # three channels with mean 0.5 and std 0.5; no crop or flip here.
        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # Shuffled training batches; a short final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        imsize=64
        datapath= '%s/test/' % (cfg.DATA_DIR)
        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="test", imsize=imsize, transform=img_transform)
        assert dataset
コード例 #7
0
        torch.cuda.manual_seed_all(args.manualSeed)

    # Build a per-run output directory name from the dataset name, config
    # name and a local-timezone timestamp.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # One GPU per comma-separated entry in cfg.GPU_ID.
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        # Training branch: random crop + horizontal flip, then normalize each
        # of the three channels with mean 0.5 and std 0.5.
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The embedding type is hard-coded to 'lstm' here rather than read
        # from the config.
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              embedding_type='lstm',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # The configured batch size is multiplied by the GPU count; a short
        # final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
コード例 #8
0
        # Pick the resize target and the image size for the configured stage.
        # NOTE(review): there is no fallback branch, so any cfg.STAGE other
        # than 1 or 2 leaves `resize` and `imsize` unbound and the code below
        # raises NameError.
        if cfg.STAGE == 1:
            resize = 76
            imsize = 64
        elif cfg.STAGE == 2:
            resize = 268
            imsize = 256

        # Resize to a square of side `resize`, convert to a tensor, then
        # normalize each of the three channels with mean 0.5 and std 0.5.
        img_transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # `imsize`, `crop` and `stage` are forwarded to TextDataset.
        # NOTE(review): presumably the dataset crops the resized image down
        # to `imsize` - confirm against TextDataset.
        dataset = TextDataset(cfg.DATA_DIR,
                              cfg.IMG_DIR,
                              split="train",
                              imsize=imsize,
                              transform=img_transform,
                              crop=True,
                              stage=cfg.STAGE)
        # Truthiness check only, and removed when Python runs with -O.
        assert dataset
        # Shuffled training batches; a short final batch is dropped.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/' % (cfg.DATA_DIR)
コード例 #9
0
# Disabled MNIST loaders, kept commented out; the TextDataset loader below
# is what is actually built.
# train_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=True, download=True,
#                    transform=transforms.ToTensor()),
#     batch_size=args.batch_size, shuffle=True, **kwargs)
# test_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
#     batch_size=args.batch_size, shuffle=True, **kwargs)

# Training transform: 64-pixel random crop + horizontal flip, then normalize
# each of the three channels with mean 0.5 and std 0.5.
image_transform = transforms.Compose([
    transforms.RandomCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = TextDataset(DATA_DIR,
                            'train',
                            imsize=64,
                            transform=image_transform)
# Truthiness check only, and removed when Python runs with -O.
assert train_dataset
# Shuffled training batches; a short final batch is dropped.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True)

######

# Rebinds image_transform to a new pipeline with the same steps as above.
image_transform = transforms.Compose([
    transforms.RandomCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])