Example #1
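Loads a trained MSG-GAN generator, draws random latent samples, and writes one grid image per output resolution.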
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """
    import os
    import numpy as np
    import torch as th
    import matplotlib.pyplot as plt
    from MSG_GAN.GAN import Generator, MSG_GAN
    from torch.nn import DataParallel

    # `device` is a module-level global in the original script;
    # defined here so the snippet is self-contained
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # create a generator:
    msg_gan_generator = Generator(depth=args.depth,
                                  latent_size=args.latent_size).to(device)

    if device == th.device("cuda"):
        msg_gan_generator = DataParallel(msg_gan_generator)

    if args.generator_file is not None:
        # load the weights into generator
        msg_gan_generator.load_state_dict(th.load(args.generator_file))

    print("Loaded Generator Configuration: ")
    print(msg_gan_generator)

    # generate all the samples in a list of lists:
    samples = []  # start with an empty list
    for _ in range(args.num_samples):
        gen_samples = msg_gan_generator(th.randn(1, args.latent_size))
        samples.append(gen_samples)

        if args.show_samples:
            for gen_sample in gen_samples:
                plt.figure()
                # move the tensor to the CPU before handing it to matplotlib
                plt.imshow(
                    th.squeeze(gen_sample.detach().cpu()).permute(1, 2, 0) / 2 + 0.5)
            plt.show()

    # create a grid of the generated samples:
    file_names = []  # initialize to empty list
    for res_val in range(args.depth):
        res_dim = np.power(2, res_val + 2)
        file_name = os.path.join(args.output_dir,
                                 str(res_dim) + "_" + str(res_dim) + ".png")
        file_names.append(file_name)

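    # transpose `samples` so the tensors are grouped per resolution,
    # then stack each group along the batch dimension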
    images = list(map(lambda x: th.cat(x, dim=0), zip(*samples)))
    MSG_GAN.create_grid(images, file_names)

    print("samples have been generated. Please check:", args.output_dir)
Example #2
File: train.py Project: manojtld/BMSG-GAN
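Standard training entry point: builds the dataset and data loader, constructs the MSG-GAN, optionally restores network and optimizer checkpoints, selects a loss function by name, and starts training.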
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """
    import numpy as np
    import torch as th
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from MSG_GAN import Losses as lses

    # `device` is a module-level global in the original script;
    # defined here so the snippet is self-contained
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # create a data source:
    data_source = FlatDirectoryImageDataset if not args.folder_distributed \
        else FoldersDistributedDataset

    dataset = data_source(args.images_dir,
                          transform=get_transform(
                              (int(np.power(2, args.depth + 1)),
                               int(np.power(2, args.depth + 1))),
                              flip_horizontal=args.flip_augment))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizer for generator:
    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    # train the GAN
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.model_dir,
                  start=args.start)
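
The loss-selection ladder above can also be written as a lookup table. A sketch over the same lses names:

LOSSES = {
    "hinge": lses.HingeGAN,
    "relativistic-hinge": lses.RelativisticAverageHingeGAN,
    "standard-gan": lses.StandardGAN,
    "lsgan": lses.LSGAN,
    "lsgan-sigmoid": lses.LSGAN_SIGMOID,
    "wgan-gp": lses.WGAN_GP,
}

try:
    loss = LOSSES[args.loss_function.lower()]
except KeyError:
    raise ValueError("Unknown loss function requested: " + args.loss_function)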
Example #3
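Conditional text-to-image variant: trains the MSG-GAN jointly with a text encoder and a conditioning augmenter on a Face2Text-style captioned dataset.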
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    import torch as th

    # the rest of this snippet relies on module-level globals from the
    # original script: `dl` (the project's data-loader module), `device`,
    # `get_config` and `train_networks`
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    # from pro_gan_pytorch.PRO_GAN import ConditionalProGAN
    from MSG_GAN.GAN import MSG_GAN
    from MSG_GAN import Losses as lses

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims)
        )
        from networks.TextEncoder import PretrainedEncoder
        # create a new session object for the pretrained encoder:
        text_encoder = PretrainedEncoder(
            model_file=config.pretrained_encoder_file,
            embedding_file=config.pretrained_embedding_file,
            device=device
        )
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(
            pro_pick_file=config.processed_text_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims),
            captions_len=config.captions_length
        )
        text_encoder = Encoder(
            embedding_size=config.embedding_size,
            vocab_size=dataset.vocab_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            device=device
        )
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.adam_beta1, config.adam_beta2),
                                      eps=config.eps)
    msg_gan = MSG_GAN(
        depth=config.depth,
        latent_size=config.latent_size,
        use_eql=config.use_eql,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    genoptim = th.optim.Adam(msg_gan.gen.parameters(), config.g_lr,
                             [config.adam_beta1, config.adam_beta2])

    disoptim = th.optim.Adam(msg_gan.dis.parameters(), config.d_lr,
                             [config.adam_beta1, config.adam_beta2])

    loss = lses.RelativisticAverageHingeGAN

    # restore saved weights, if checkpoint files were given

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        use_eql=config.use_eql,
        device=device
    )

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.adam_beta1, config.adam_beta2),
                             eps=config.eps)

    print("Generator Config:")
    print(msg_gan.gen)

    print("\nDiscriminator Config:")
    print(msg_gan.dis)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        msg_gan=msg_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        gen_optim=genoptim,
        dis_optim=disoptim,
        loss_fn=loss(msg_gan.dis),
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )
Example #4
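Feature-extraction script: builds a multi-resolution image pyramid for each batch, feeds it to a trained discriminator, and saves the extracted features to a pickle file.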
def main():
    import os
    import numpy as np
    import torch as th
    from torch.backends import cudnn
    cudnn.benchmark = True
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    from pinglib.files import get_file_list, create_dir
    from pinglib.utils import save_variables
    from PIL import Image

    image_folder = r"D:\Projects\anomaly_detection\datasets\Camelyon\test_negative"
    save_path = r"D:\Projects\anomaly_detection\BMSG_GAN_test_neg.pkl"
    model_path = r"D:\Projects\anomaly_detection\progresses\MSG-GAN\Models\GAN_DIS_73.pth"

    '''-----------------建立数据集和数据载入器----------------'''

    from torch.utils.data import Dataset
    from torchvision.transforms import ToTensor, Resize, Compose, Normalize

    class Dataset4extract(Dataset):
        def __init__(self, image_paths):
            self.image_paths = image_paths
            self.transform = Compose([
                ToTensor(),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            img = Image.open(self.image_paths[idx])

            # drop a possible alpha channel *before* the transform:
            # the Normalize above expects exactly 3 channels
            img = img.convert('RGB')

            img = self.transform(img)
            return img

    image_paths = get_file_list(image_folder, ext='jpg')
    dataset = Dataset4extract(image_paths)
    print("Total number of images in the dataset:", len(dataset))

    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    '''-----------------建立模型----------------'''
    from MSG_GAN.GAN import MSG_GAN
    depth = 7
    msg_gan = MSG_GAN(depth=depth,
                      latent_size=512,
                      use_eql=True,
                      use_ema=True,
                      ema_decay=0.999,
                      device=device)

    msg_gan.dis.load_state_dict(th.load(model_path))

    '''-----------------进行评估----------------'''
    features = []
    from torch.nn.functional import avg_pool2d

    for batch in dataloader:
        # build the multi-resolution image inputs by repeatedly
        # average-pooling the full-resolution batch
        images = batch.to(device)

        images = [images] + [avg_pool2d(images, int(np.power(2, i)))
                             for i in range(1, depth)]
        images = list(reversed(images))

        # feed the image pyramid to the model
        # (`extract` appears to be a project-specific hook, not upstream API)
        feature = msg_gan.extract(images)
        features.append(feature.detach().cpu().numpy())

    '''-----------------保存结果----------------'''
    features = np.concatenate(features, axis=0)
    save_variables([features], save_path)
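
To confirm which resolutions the pooling loop produces, here is a standalone check. It assumes, as the snippet does, that input images already measure 2 ** (depth + 1) pixels per side:

import numpy as np
import torch as th
from torch.nn.functional import avg_pool2d

depth = 7
x = th.randn(1, 3, 2 ** (depth + 1), 2 ** (depth + 1))  # a fake 256x256 batch
pyramid = [x] + [avg_pool2d(x, int(np.power(2, i))) for i in range(1, depth)]
pyramid = list(reversed(pyramid))
print([t.shape[-1] for t in pyramid])  # [4, 8, 16, 32, 64, 128, 256]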
Example #5
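Training entry point for a style-based MSG-GAN variant: the style mapping network gets a reduced learning rate via its own parameter group, and CIFAR-10 can be selected as a built-in dataset.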
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """
    import datetime
    import dateutil.tz
    import numpy as np
    import torch as th
    from torchvision.datasets import CIFAR10
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset, IgnoreLabels
    from MSG_GAN import Losses as lses

    # `device` is a module-level global in the original script;
    # defined here so the snippet is self-contained
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # create a data source:
    if args.pytorch_dataset is None:
        data_source = FlatDirectoryImageDataset if not args.folder_distributed \
            else FoldersDistributedDataset

        dataset = data_source(args.images_dir,
                              transform=get_transform(
                                  (int(np.power(2, args.depth + 1)),
                                   int(np.power(2, args.depth + 1))),
                                  flip_horizontal=args.flip_augment))
    else:
        dataset_name = args.pytorch_dataset.lower()
        if dataset_name == "cifar10":
            dataset = IgnoreLabels(
                CIFAR10(args.dataset_dir,
                        transform=get_transform(
                            (int(np.power(2, args.depth + 1)),
                             int(np.power(2, args.depth + 1))),
                            flip_horizontal=args.flip_augment),
                        download=True))
        else:
            raise Exception("Unknown dataset requested")

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    # print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizer for the generator: the style mapping network is
    # trained at 1/100th of the base learning rate
    gen_params = [{
        'params': msg_gan.gen.style.parameters(),
        'lr': args.g_lr * 0.01,
        'mult': 0.01
    }, {
        'params': msg_gan.gen.layers.parameters(),
        'lr': args.g_lr
    }, {
        'params': msg_gan.gen.rgb_converters.parameters(),
        'lr': args.g_lr
    }]
    gen_optim = th.optim.Adam(gen_params, args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    if args.pytorch_dataset is not None:
        dataName = 'cifar'
    elif args.images_dir.find('celeb') != -1:
        dataName = 'celeb'
    else:
        dataName = 'flowers'
    output_dir = 'output/%s_%s_%s' % ('attnmsggan', dataName, timestamp)
    args.model_dir = output_dir + '/models'
    args.sample_dir = output_dir + '/images'
    args.log_dir = output_dir + '/logs'

    # train the GAN
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.log_dir,
                  start=args.start)
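
As a quick sanity check on the per-group learning rates, the groups can be inspected after the optimizer is built. The extra 'mult' key is simply carried along in the group dict (PyTorch optimizers keep unknown group keys), presumably for use by a custom scheduler:

for group in gen_optim.param_groups:
    print(group["lr"], group.get("mult", 1.0))
# prints args.g_lr * 0.01 with mult 0.01 for the style network,
# and args.g_lr with mult 1.0 for the layer and RGB-converter groups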