def main(args):
    """
    Main function for the script: sample images from an MSG-GAN generator
    and save one grid image per resolution level.

    :param args: parsed command line arguments; this function reads
                 ``depth``, ``latent_size``, ``generator_file``,
                 ``num_samples``, ``show_samples`` and ``output_dir``
    :return: None
    """
    from MSG_GAN.GAN import Generator, MSG_GAN
    from torch.nn import DataParallel

    # create a generator:
    msg_gan_generator = Generator(depth=args.depth,
                                  latent_size=args.latent_size).to(device)
    if device == th.device("cuda"):
        msg_gan_generator = DataParallel(msg_gan_generator)

    if args.generator_file is not None:
        # load the weights into generator; map_location remaps the stored
        # tensors onto `device`, so a checkpoint written on a GPU can also be
        # loaded on a CPU-only machine
        msg_gan_generator.load_state_dict(
            th.load(args.generator_file, map_location=device))

    print("Loaded Generator Configuration: ")
    print(msg_gan_generator)

    # generate all the samples in a list of lists (one entry per sample,
    # each holding the generator's outputs). This is inference only, so
    # autograd is disabled: otherwise every stored output would keep its
    # whole computation graph alive until the end of the function.
    samples = []
    with th.no_grad():
        for _ in range(args.num_samples):
            # draw the noise directly on `device`; this matters when the
            # generator lives on a GPU but was not wrapped in DataParallel
            # (e.g. device == th.device("cuda:0"), which != th.device("cuda"))
            noise = th.randn(1, args.latent_size, device=device)
            gen_samples = msg_gan_generator(noise)
            samples.append(gen_samples)

            if args.show_samples:
                for gen_sample in gen_samples:
                    plt.figure()
                    # matplotlib needs a CPU tensor; `/ 2 + 0.5` rescales the
                    # output for display (presumably from [-1, 1] to [0, 1])
                    plt.imshow(
                        th.squeeze(gen_sample).cpu().permute(1, 2, 0) / 2 + 0.5)
                    plt.show()

    # create a grid of the generated samples: one file per resolution level,
    # named "<res>_<res>.png" with res = 4, 8, 16, ... (2 ** (level + 2))
    file_names = []
    for res_val in range(args.depth):
        res_dim = 2 ** (res_val + 2)
        file_names.append(os.path.join(
            args.output_dir, str(res_dim) + "_" + str(res_dim) + ".png"))

    # regroup the outputs by scale and stack each scale along dim 0
    images = [th.cat(per_scale, dim=0) for per_scale in zip(*samples)]
    MSG_GAN.create_grid(images, file_names)

    print("samples have been generated. Please check:", args.output_dir)
def main(args):
    """
    Main function for the script: build the dataset and an MSG-GAN,
    optionally resume weights / optimizer states from checkpoint files,
    then train.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: if ``args.loss_function`` is not one of the
                        supported loss names
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from MSG_GAN import Losses as lses

    # Resolve the loss up front so that a typo in the loss name fails before
    # the dataset and the networks are built. Names map to attribute names
    # so only the requested class has to exist in the Losses module.
    loss_classes = {
        "hinge": "HingeGAN",
        "relativistic-hinge": "RelativisticAverageHingeGAN",
        "standard-gan": "StandardGAN",
        "lsgan": "LSGAN",
        "lsgan-sigmoid": "LSGAN_SIGMOID",
        "wgan-gp": "WGAN_GP",
    }
    loss_name = args.loss_function.lower()
    if loss_name not in loss_classes:
        # ValueError subclasses Exception, so existing handlers still match
        raise ValueError("Unknown loss function requested: "
                         + str(args.loss_function))
    loss = getattr(lses, loss_classes[loss_name])

    # create a data source; images are resized to 2 ** (depth + 1) pixels
    # per side
    data_source = FoldersDistributedDataset if args.folder_distributed \
        else FlatDirectoryImageDataset
    img_size = int(np.power(2, args.depth + 1))
    dataset = data_source(
        args.images_dir,
        transform=get_transform((img_size, img_size),
                                flip_horizontal=args.flip_augment))
    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    # Every checkpoint below is loaded with map_location=device so that files
    # saved on a GPU can be restored on a CPU-only machine (and vice versa).
    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(
            th.load(args.generator_file, map_location=device))

    print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(
            th.load(args.shadow_generator_file, map_location=device))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(
            th.load(args.discriminator_file, map_location=device))

    print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizers for generator and discriminator:
    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(
            th.load(args.generator_optim_file, map_location=device))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(
            th.load(args.discriminator_optim_file, map_location=device))

    # train the GAN; logs are written to the same directory as the models
    msg_gan.train(
        data,
        gen_optim,
        dis_optim,
        loss_fn=loss(msg_gan.dis),
        num_epochs=args.num_epochs,
        checkpoint_factor=args.checkpoint_factor,
        data_percentage=args.data_percentage,
        feedback_factor=args.feedback_factor,
        num_samples=args.num_samples,
        sample_dir=args.sample_dir,
        save_dir=args.model_dir,
        log_dir=args.model_dir,
        start=args.start
    )
def main(args):
    """
    Main function for the script: train a text-conditioned MSG-GAN.

    Builds these components from the configuration file named by
    ``args.config``:

    * the caption/image dataset;
    * a text encoder, which is either a pretrained one (no optimizer is
      created for it) or a trainable ``Encoder``;
    * a conditioning augmenter;
    * an MSG-GAN.

    It then optionally restores their weights and hands everything to
    ``train_networks``.

    :param args: parsed command line arguments
    :return: None
    """
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from MSG_GAN.GAN import MSG_GAN
    from MSG_GAN import Losses as lses

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training, together with the matching encoder
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims)
        )
        from networks.TextEncoder import PretrainedEncoder
        # create a new session object for the pretrained encoder:
        text_encoder = PretrainedEncoder(
            model_file=config.pretrained_encoder_file,
            embedding_file=config.pretrained_embedding_file,
            device=device
        )
        # no optimizer is created for the pretrained encoder
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(
            pro_pick_file=config.processed_text_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims),
            captions_len=config.captions_length
        )
        text_encoder = Encoder(
            embedding_size=config.embedding_size,
            vocab_size=dataset.vocab_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            device=device
        )
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.adam_beta1, config.adam_beta2),
                                      eps=config.eps)

    # create the GAN, its optimizers and its loss
    msg_gan = MSG_GAN(
        depth=config.depth,
        latent_size=config.latent_size,
        use_eql=config.use_eql,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    genoptim = th.optim.Adam(msg_gan.gen.parameters(), config.g_lr,
                             [config.adam_beta1, config.adam_beta2])

    disoptim = th.optim.Adam(msg_gan.dis.parameters(), config.d_lr,
                             [config.adam_beta1, config.adam_beta2])

    loss = lses.RelativisticAverageHingeGAN

    # Optionally restore weights. map_location=device lets checkpoints saved
    # on a GPU be loaded on a CPU-only machine (and vice versa).
    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(
            th.load(args.encoder_file, map_location=device))

    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        use_eql=config.use_eql,
        device=device
    )

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(
            th.load(args.ca_file, map_location=device))

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        msg_gan.gen.load_state_dict(
            th.load(args.generator_file, map_location=device))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(
            th.load(args.discriminator_file, map_location=device))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.adam_beta1, config.adam_beta2),
                             eps=config.eps)

    print("Generator Config:")
    print(msg_gan.gen)

    print("\nDiscriminator Config:")
    print(msg_gan.dis)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        msg_gan=msg_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        gen_optim=genoptim,
        dis_optim=disoptim,
        loss_fn=loss(msg_gan.dis),
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )
class _Dataset4extract:
    """
    Map-style dataset of images given by path, as 3-channel tensors in [-1, 1].

    The class is defined at module level rather than inside ``main`` for a
    picklability reason. DataLoader worker processes started with "spawn"
    (the default on Windows) must pickle the dataset, and instances of a
    function-local class cannot be pickled.

    It is a plain class rather than a ``torch.utils.data.Dataset`` subclass
    only so that torch/torchvision/PIL stay imported lazily. DataLoader only
    needs ``__len__`` and ``__getitem__`` from a map-style dataset.
    """

    def __init__(self, image_paths):
        from torchvision.transforms import Compose, Normalize, ToTensor

        self.image_paths = image_paths
        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps to [-1, 1]
        self.transform = Compose([
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        from PIL import Image

        # convert("RGB") drops an alpha channel if present and also expands
        # grayscale/palette images to the 3 channels Normalize expects. The
        # `with` closes the file handle that Image.open keeps open.
        with Image.open(self.image_paths[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)


def main(image_folder=r"D:\Projects\anomaly_detection\datasets\Camelyon\test_negative",
         save_path=r"D:\Projects\anomaly_detection\BMSG_GAN_test_neg.pkl",
         model_path=r"D:\Projects\anomaly_detection\progresses\MSG-GAN\Models\GAN_DIS_73.pth",
         depth=7,
         num_workers=2):
    """
    Extract MSG-GAN discriminator features for every ``.jpg`` in a folder.

    Each image is fed to ``MSG_GAN.extract`` as a multi-scale pyramid. The
    per-image features are concatenated along axis 0 and written to
    ``save_path`` with ``pinglib.utils.save_variables``. All parameters
    default to the values that used to be hard-coded here.

    :param image_folder: folder searched for ``.jpg`` files
    :param save_path: output file for the stacked feature array
    :param model_path: discriminator ``state_dict`` checkpoint
    :param depth: MSG-GAN depth; should match the depth the checkpoint was
                  trained with
    :param num_workers: number of DataLoader worker processes
    :return: None
    :raises ValueError: if no ``.jpg`` images are found in ``image_folder``
    """
    import numpy as np
    import torch as th
    from torch.backends import cudnn
    from torch.nn.functional import avg_pool2d
    from torch.utils.data import DataLoader

    from MSG_GAN.GAN import MSG_GAN
    from pinglib.files import get_file_list
    from pinglib.utils import save_variables

    cudnn.benchmark = True
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # ----------------- build the dataset and data loader -----------------
    image_paths = get_file_list(image_folder, ext='jpg')
    dataset = _Dataset4extract(image_paths)
    print("Total number of images in the dataset:", len(dataset))
    if len(dataset) == 0:
        # np.concatenate at the end would otherwise fail with an opaque
        # error, and only after the model had been built
        raise ValueError("No .jpg images found in: " + str(image_folder))

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=num_workers)

    # ----------------- build the model -----------------
    msg_gan = MSG_GAN(depth=depth, latent_size=512, use_eql=True,
                      use_ema=True, ema_decay=0.999, device=device)
    # map_location lets a GPU-saved checkpoint be loaded on a CPU-only machine
    msg_gan.dis.load_state_dict(th.load(model_path, map_location=device))
    # standard for inference; only has an effect if the discriminator
    # contains train/eval-dependent layers (dropout, batch norm, ...)
    msg_gan.dis.eval()

    # ----------------- run the evaluation -----------------
    features = []
    with th.no_grad():  # inference only: no autograd graph is needed
        for batch in dataloader:
            images = batch.to(device)
            # build the multi-resolution input: the full-size batch plus
            # copies average-pooled by 2, 4, ..., 2 ** (depth - 1), ordered
            # from the lowest resolution to the highest
            pyramid = [images] + [avg_pool2d(images, 2 ** level)
                                  for level in range(1, depth)]
            pyramid.reverse()
            # feed the images to the model
            feature = msg_gan.extract(pyramid)
            features.append(feature.detach().cpu().numpy())

    # ----------------- save the results -----------------
    features = np.concatenate(features, axis=0)
    save_variables([features], save_path)
def main(args):
    """
    Main function for the script: train an MSG-GAN on an image folder or on
    CIFAR-10.

    Output paths are derived from a fresh timestamped directory
    ``output/attnmsggan_<data>_<timestamp>``. Its ``models``, ``images`` and
    ``logs`` paths overwrite ``args.model_dir``, ``args.sample_dir`` and
    ``args.log_dir`` before training starts.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: for an unknown ``args.loss_function`` or
                        ``args.pytorch_dataset``
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset, IgnoreLabels
    from MSG_GAN import Losses as lses

    # Resolve the loss up front so that a typo in the loss name fails before
    # any data is loaded or downloaded. Names map to attribute names so only
    # the requested class has to exist in the Losses module.
    loss_classes = {
        "hinge": "HingeGAN",
        "relativistic-hinge": "RelativisticAverageHingeGAN",
        "standard-gan": "StandardGAN",
        "lsgan": "LSGAN",
        "lsgan-sigmoid": "LSGAN_SIGMOID",
        "wgan-gp": "WGAN_GP",
    }
    loss_name = args.loss_function.lower()
    if loss_name not in loss_classes:
        # ValueError subclasses Exception, so existing handlers still match
        raise ValueError("Unknown loss function requested: "
                         + str(args.loss_function))
    loss = getattr(lses, loss_classes[loss_name])

    # create a data source; every dataset is resized to 2 ** (depth + 1)
    # pixels per side with the same transform
    img_size = int(np.power(2, args.depth + 1))
    transform = get_transform((img_size, img_size),
                              flip_horizontal=args.flip_augment)

    if args.pytorch_dataset is None:
        data_source = FoldersDistributedDataset if args.folder_distributed \
            else FlatDirectoryImageDataset
        dataset = data_source(args.images_dir, transform=transform)
    else:
        dataset_name = args.pytorch_dataset.lower()
        if dataset_name != "cifar10":
            raise ValueError("Unknown dataset requested: "
                             + str(args.pytorch_dataset))
        dataset = IgnoreLabels(
            CIFAR10(args.dataset_dir, transform=transform, download=True))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    # Every checkpoint below is loaded with map_location=device so that files
    # saved on a GPU can be restored on a CPU-only machine (and vice versa).
    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(
            th.load(args.generator_file, map_location=device))

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(
            th.load(args.shadow_generator_file, map_location=device))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(
            th.load(args.discriminator_file, map_location=device))

    # Create the optimizer for the generator. The style sub-network uses a
    # 100x smaller learning rate. The extra 'mult' key is carried in its
    # param group; it is not read in this function (presumably used
    # elsewhere — confirm).
    gen_params = [{
        'params': msg_gan.gen.style.parameters(),
        'lr': args.g_lr * 0.01,
        'mult': 0.01
    }, {
        'params': msg_gan.gen.layers.parameters(),
        'lr': args.g_lr
    }, {
        'params': msg_gan.gen.rgb_converters.parameters(),
        'lr': args.g_lr
    }]
    gen_optim = th.optim.Adam(gen_params, args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(
            th.load(args.generator_optim_file, map_location=device))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(
            th.load(args.discriminator_optim_file, map_location=device))

    # derive a fresh, timestamped output directory for this run
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    if args.pytorch_dataset is not None:
        data_name = 'cifar'
    elif 'celeb' in args.images_dir:
        data_name = 'celeb'
    else:
        data_name = 'flowers'
    output_dir = 'output/attnmsggan_%s_%s' % (data_name, timestamp)

    # NOTE: these overwrite any directories supplied on the command line
    args.model_dir = output_dir + '/models'
    args.sample_dir = output_dir + '/images'
    args.log_dir = output_dir + '/logs'

    # train the GAN
    msg_gan.train(
        data,
        gen_optim,
        dis_optim,
        loss_fn=loss(msg_gan.dis),
        num_epochs=args.num_epochs,
        checkpoint_factor=args.checkpoint_factor,
        data_percentage=args.data_percentage,
        feedback_factor=args.feedback_factor,
        num_samples=args.num_samples,
        sample_dir=args.sample_dir,
        save_dir=args.model_dir,
        log_dir=args.log_dir,
        start=args.start
    )