Example #1
File: auto.py Project: aigagror/deep-plca
    def __init__(self, channels, imsize, zdim):
        super().__init__()

        # Encoder
        self.encoder = Encoder(channels, zdim)

        # Decoder
        self.decoder = Decoder(zdim, channels, imsize, hdim=256)
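
For orientation, a minimal usage sketch of this module. The class name AutoEncoder is hypothetical, as is the assumption that Encoder maps images to a zdim-sized code and Decoder maps it back; neither is taken from the project:

import torch

model = AutoEncoder(channels=3, imsize=32, zdim=64)  # hypothetical class name
x = torch.randn(8, 3, 32, 32)      # a batch of 8 RGB 32x32 images
z = model.encoder(x)               # encode to the latent space
x_hat = model.decoder(z)           # decode back to image space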
Example #2
File: auto.py Project: aigagror/deep-plca
    def __init__(self, channels, imsize, zdim):
        super().__init__()
        self.imsize = imsize

        # Image size should be a power of 2
        upsamples = math.log(imsize, 2)
        assert upsamples.is_integer()
        upsamples = int(upsamples)

        # Encoder
        self.encoder = Encoder(channels, zdim)

        draw_layers = []
        hdim = 128
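        # Chain of DrawLayers: presumably zdim -> hdim at the first layer,
        # hdim -> channels at the last, and hdim -> hdim in between.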
        for i in range(upsamples):
            if i == 0:
                draw_layers.append(DrawLayer(zdim, hdim, channels, imsize))
            elif i == upsamples - 1:
                draw_layers.append(DrawLayer(hdim, channels, channels, imsize))
            else:
                draw_layers.append(DrawLayer(hdim, hdim, channels, imsize))

        self.draw_layers = nn.Sequential(*draw_layers)
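
The power-of-two guard above relies on math.log(imsize, 2) returning a whole-number float exactly when imsize is a power of two; math.log2 would be the more direct spelling. A quick check:

import math

math.log(32, 2)    # 5.0   -> .is_integer() is True, so 5 upsampling stages
math.log(28, 2)    # ~4.81 -> .is_integer() is False, the assert fires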
Example #3
mnist_train_dataset = CustomTensorDataset(
    torch.from_numpy(mnist_x_train).float(),
    torch.from_numpy(mnist_y_train).long(),
    transform=train_transform)
kannada_train_dataset = CustomTensorDataset(
    torch.from_numpy(kannada_x_train).float(),
    torch.from_numpy(kannada_y_train).long(),
    transform=train_transform)
kannada_val_dataset = CustomTensorDataset(
    torch.from_numpy(kannada_x_val).float(),
    torch.from_numpy(kannada_y_val).long(),
    transform=train_transform)

train_loader_1 = DataLoader(mnist_train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)
train_loader_2 = DataLoader(kannada_train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)
val_loader = DataLoader(kannada_val_dataset,
                        batch_size=args.batch_size,
                        shuffle=False,
                        num_workers=args.num_workers)

num_batch = min(len(train_loader_1), len(train_loader_2))
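# num_batch above caps a joint epoch: iterating the two loaders in lockstep,
# e.g. via zip(train_loader_1, train_loader_2), yields min(len, len) batches.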

dim_z = 512

# Modules
shared_latent = GaussianVAE2D(dim_z, dim_z, kernel_size=1, stride=1).to(device)
encoder_1 = Encoder().to(device)
encoder_2 = Encoder().to(device)
decoder_1 = Decoder().to(device)
decoder_2 = Decoder().to(device)
discriminator_1 = Discriminator().to(device)
discriminator_2 = Discriminator().to(device)

criterion_discr = nn.BCELoss(reduction='sum')
criterion_class = nn.CrossEntropyLoss(reduction='sum')

fixed_noise = torch.randn(args.batch_size, dim_z, 1, 1, device=device)
real_label = 1
fake_label = 0
label_noise = 0.1

# setup optimizer
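
The excerpt stops at the optimizer comment. As a sketch of what typically follows, assuming an Adam optimizer and a generator/discriminator parameter split; the hyperparameters are assumptions, not taken from the source:

import itertools
import torch.optim as optim

gen_params = itertools.chain(shared_latent.parameters(),
                             encoder_1.parameters(), encoder_2.parameters(),
                             decoder_1.parameters(), decoder_2.parameters())
dis_params = itertools.chain(discriminator_1.parameters(),
                             discriminator_2.parameters())
optimizer_gen = optim.Adam(gen_params, lr=1e-4, betas=(0.5, 0.999))
optimizer_dis = optim.Adam(dis_params, lr=1e-4, betas=(0.5, 0.999))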
Example #4
def main():
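    # Attach to a remote PyCharm debugger before anything else runs;
    # the host IP and port below are machine-specific.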
    import pydevd_pycharm
    pydevd_pycharm.settrace('172.26.3.54',
                            port=12345,
                            stdoutToServer=True,
                            stderrToServer=True)
    parser = argparse.ArgumentParser(description="DFDGAN Showing G pic")
    parser.add_argument("--config_file",
                        default="./configs/show_pic.yml",
                        help="path to config file",
                        type=str)
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    datasets_dir = '-'.join(cfg.DATASETS.NAMES)
    output_dir = os.path.join(output_dir, datasets_dir)
    time_string = 'show_pic[{}]'.format(
        time.strftime('%Y-%m-%d-%X', time.localtime(time.time())))
    output_dir = os.path.join(output_dir, time_string)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    device = cfg.TEST.DEVICE
    if device == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.TEST.DEVICE_ID
    cudnn.benchmark = True
    logger = setup_logger("DFDGAN", output_dir, 0)
    logger.info("Running with config:\n{}".format(cfg))

    data_loader, num_classes = make_dataloaders(cfg)
    E = Encoder(num_classes, cfg.E.LAST_STRIDE, cfg.E.PRETRAIN_PATH,
                cfg.E.NECK, cfg.TEST.NECK_FEAT, cfg.E.NAME,
                cfg.E.PRETRAIN_CHOICE).to(device)
    Ed = Encoder(num_classes, cfg.ED.LAST_STRIDE, cfg.ED.PRETRAIN_PATH,
                 cfg.ED.NECK, cfg.TEST.NECK_FEAT, cfg.ED.NAME,
                 cfg.ED.PRETRAIN_CHOICE).to(device)
    G = DFDGenerator(cfg.G.PRETRAIN_PATH,
                     cfg.G.PRETRAIN_CHOICE,
                     noise_size=cfg.TRAIN.NOISE_SIZE).to(device)
    for batch in data_loader:
        img_x1, img_x2, img_y1, img_y2, target_pid, target_setid = batch
        img_x1, img_x2 = img_x1.to(device), img_x2.to(device)
        img_y1, img_y2 = img_y1.to(device), img_y2.to(device)
        target_pid, target_setid = target_pid.to(device), target_setid.to(device)
        g_img = G(E(img_x1)[0], Ed(img_y1)[0])
        img_x1_PIL = transforms.ToPILImage()(img_x1[0].cpu()).convert('RGB')
        img_x1_PIL.save(os.path.join(output_dir, 'img_x1.jpg'))
        img_y1_PIL = transforms.ToPILImage()(img_y1[0].cpu()).convert('RGB')
        img_y1_PIL.save(os.path.join(output_dir, 'img_y1.jpg'))
        g_img_PIL = transforms.ToPILImage()(g_img[0].cpu()).convert('RGB')
        g_img_PIL.save(os.path.join(output_dir, 'g_img.jpg'))
        break
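
Because opts is collected with argparse.REMAINDER, any trailing key-value pairs on the command line are handed to cfg.merge_from_list before cfg.freeze(). Assuming cfg is a yacs CfgNode, the call is equivalent to:

# Same effect as appending "TEST.DEVICE cuda TEST.DEVICE_ID 0" to the command line.
cfg.merge_from_list(['TEST.DEVICE', 'cuda', 'TEST.DEVICE_ID', '0'])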
Example #5
mnist_train_dataset = CustomTensorDataset(
    torch.from_numpy(mnist_x_train).float(),
    torch.from_numpy(mnist_y_train).long(),
    transform=train_transform)
kannada_train_dataset = CustomTensorDataset(
    torch.from_numpy(kannada_x_train).float(),
    torch.from_numpy(kannada_y_train).long(),
    transform=train_transform)

mnist_loader = DataLoader(mnist_train_dataset,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=args.num_workers)
kannada_loader = DataLoader(kannada_train_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)

# Modules
dim_z = 512
shared_latent = GaussianVAE2D(dim_z, dim_z, kernel_size=1, stride=1).to(device)
encoder_1 = Encoder().to(device)
encoder_2 = Encoder().to(device)
decoder_1 = Decoder().to(device)
decoder_2 = Decoder().to(device)
discriminator_1 = Discriminator().to(device)
discriminator_2 = Discriminator().to(device)

if os.path.isfile(args.model):
    print("===> Loading Checkpoint to Evaluate '{}'".format(args.model))
    checkpoint = torch.load(args.model)
    shared_latent.load_state_dict(checkpoint['shared_latent'])
    encoder_1.load_state_dict(checkpoint['encoder_1'])
    encoder_2.load_state_dict(checkpoint['encoder_2'])
    decoder_1.load_state_dict(checkpoint['decoder_1'])
    decoder_2.load_state_dict(checkpoint['decoder_2'])
    discriminator_1.load_state_dict(checkpoint['discriminator_1'])