def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from MSG_GAN import Losses as lses

    # create a data source:
    data_source = FlatDirectoryImageDataset if not args.folder_distributed \
        else FoldersDistributedDataset

    dataset = data_source(
        args.images_dir,
        transform=get_transform(
            (int(np.power(2, args.depth + 1)),
             int(np.power(2, args.depth + 1))),
            flip_horizontal=args.flip_augment))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    if args.generator_file is not None:
        # load the weights into the generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into the discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizers for the generator and discriminator:
    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])
    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    # select the loss function
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    # train the GAN
    msg_gan.train(data, gen_optim, dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.model_dir,
                  start=args.start)
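
# ----------------------------------------------------------------------------
# Illustrative driver (a sketch, NOT the repository's actual parse_arguments):
# main() expects an argparse namespace carrying the attributes read above.
# The flag names mirror those attributes; the defaults below are assumptions
# chosen only for illustration.
# ----------------------------------------------------------------------------
import argparse


def _sketch_parse_arguments():
    """ hypothetical command line parser matching the attributes used by main """
    parser = argparse.ArgumentParser(description="MSG-GAN training (sketch)")

    # data options
    parser.add_argument("--images_dir", type=str, required=True)
    parser.add_argument("--folder_distributed", action="store_true")
    parser.add_argument("--flip_augment", action="store_true")

    # checkpoints to resume from (all optional)
    parser.add_argument("--generator_file", type=str, default=None)
    parser.add_argument("--shadow_generator_file", type=str, default=None)
    parser.add_argument("--discriminator_file", type=str, default=None)
    parser.add_argument("--generator_optim_file", type=str, default=None)
    parser.add_argument("--discriminator_optim_file", type=str, default=None)

    # model and training hyperparameters
    parser.add_argument("--depth", type=int, default=7)
    parser.add_argument("--latent_size", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=3)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--start", type=int, default=1)
    parser.add_argument("--g_lr", type=float, default=0.003)
    parser.add_argument("--d_lr", type=float, default=0.003)
    parser.add_argument("--adam_beta1", type=float, default=0)
    parser.add_argument("--adam_beta2", type=float, default=0.99)
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--loss_function", type=str, default="relativistic-hinge")
    parser.add_argument("--data_percentage", type=float, default=100)

    # feedback / output options
    parser.add_argument("--feedback_factor", type=int, default=10)
    parser.add_argument("--num_samples", type=int, default=36)
    parser.add_argument("--checkpoint_factor", type=int, default=1)
    parser.add_argument("--sample_dir", type=str, default="samples/")
    parser.add_argument("--model_dir", type=str, default="models/")

    # boolean switches kept as plain defaults to avoid argparse bool pitfalls
    parser.set_defaults(use_eql=True, use_ema=True)

    return parser.parse_args()


# usage (commented out so the sketch never triggers training on import):
# if __name__ == "__main__":
#     main(_sketch_parse_arguments())
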
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset, IgnoreLabels
    from MSG_GAN import Losses as lses

    # create a data source:
    if args.pytorch_dataset is None:
        data_source = FlatDirectoryImageDataset if not args.folder_distributed \
            else FoldersDistributedDataset
        dataset = data_source(
            args.images_dir,
            transform=get_transform(
                (int(np.power(2, args.depth + 1)),
                 int(np.power(2, args.depth + 1))),
                flip_horizontal=args.flip_augment))
    else:
        dataset_name = args.pytorch_dataset.lower()
        if dataset_name == "cifar10":
            dataset = IgnoreLabels(
                CIFAR10(args.dataset_dir,
                        transform=get_transform(
                            (int(np.power(2, args.depth + 1)),
                             int(np.power(2, args.depth + 1))),
                            flip_horizontal=args.flip_augment),
                        download=True))
        else:
            raise Exception("Unknown dataset requested")

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    if args.generator_file is not None:
        # load the weights into the generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    # print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into the discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizer for the generator:
    # the style (mapping) network uses a 100x smaller learning rate
    gen_params = [
        {'params': msg_gan.gen.style.parameters(),
         'lr': args.g_lr * 0.01, 'mult': 0.01},
        {'params': msg_gan.gen.layers.parameters(), 'lr': args.g_lr},
        {'params': msg_gan.gen.rgb_converters.parameters(), 'lr': args.g_lr}
    ]
    gen_optim = th.optim.Adam(gen_params, args.g_lr,
                              [args.adam_beta1, args.adam_beta2])
    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    # select the loss function
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    # build a timestamped output directory for models, samples and logs
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    if args.pytorch_dataset is not None:
        dataName = 'cifar'
    elif args.images_dir.find('celeb') != -1:
        dataName = 'celeb'
    else:
        dataName = 'flowers'

    output_dir = 'output/%s_%s_%s' % ('attnmsggan', dataName, timestamp)
    args.model_dir = output_dir + '/models'
    args.sample_dir = output_dir + '/images'
    args.log_dir = output_dir + '/logs'

    # train the GAN
    msg_gan.train(data, gen_optim, dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.log_dir,
                  start=args.start)
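
# ----------------------------------------------------------------------------
# Side note (illustrative sketch, not part of the training script): the
# gen_params list above trains the style (mapping) network with a learning
# rate 100x smaller than the rest of the generator, similar to the StyleGAN
# recipe. torch.optim keeps the extra 'mult' key inside each param group, so
# a custom schedule can rescale the base rate while preserving that ratio.
# The tiny Linear modules below are hypothetical stand-ins for gen.style,
# gen.layers and gen.rgb_converters.
# ----------------------------------------------------------------------------
def _demo_per_group_lr():
    import torch as th
    from torch import nn

    style = nn.Linear(8, 8)    # stand-in for the style / mapping network
    layers = nn.Linear(8, 8)   # stand-in for the synthesis layers
    rgb = nn.Linear(8, 3)      # stand-in for the to-rgb converters

    base_lr = 0.003
    optim = th.optim.Adam(
        [{"params": style.parameters(), "lr": base_lr * 0.01, "mult": 0.01},
         {"params": layers.parameters(), "lr": base_lr},
         {"params": rgb.parameters(), "lr": base_lr}],
        base_lr, [0, 0.99])

    # decay the base rate by half while keeping the 100x gap for the style net
    new_base_lr = base_lr * 0.5
    for group in optim.param_groups:
        group["lr"] = new_base_lr * group.get("mult", 1.0)

    print([group["lr"] for group in optim.param_groups])
    # -> [1.5e-05, 0.0015, 0.0015]
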