# Example #1 (score: 0)
 def __init__(self, opt):
     """Set up data loaders, discriminators, optimizers and the BCE loss
     for two-phase (main network / tuning blocks) GAN training."""
     super(TrainingModel, self).__init__(opt)
     utils.print_options(opt)

     # One training loader for the main network, one for the tuning blocks.
     self.train_main_loader = utils.get_data_loader(opt, train=True, main=True)
     self.train_tuning_blocks_loader = utils.get_data_loader(opt,
                                                             train=True,
                                                             main=False)

     # Discriminators: optionally reuse a single instance for both phases.
     self.main_disc = Discriminator().to(self.device)
     self.tuning_blocks_disc = (
         self.main_disc
         if self.opt.tuning_blocks_disc_same_as_main_disc
         else Discriminator().to(self.device))

     def _adam(parameters, lr):
         # Local factory: all four optimizers share the same beta schedule.
         return Adam(parameters, lr=lr, betas=(0.5, 0.999))

     self.main_gen_optimizer = _adam(self.net.main.parameters(),
                                     opt.gen_learning_rate_main)
     self.tuning_blocks_gen_optimizer = _adam(
         self.net.tuning_blocks.parameters(),
         opt.gen_learning_rate_tuning_blocks)
     self.main_disc_optimizer = _adam(self.main_disc.parameters(),
                                      opt.disc_learning_rate_main)
     self.tuning_blocks_disc_optimizer = _adam(
         self.tuning_blocks_disc.parameters(),
         opt.disc_learning_rate_tuning_blocks)

     # Fixed noise batch reshaped to (N, z_size, 1, 1) for periodic evaluation.
     noise = torch.randn(opt.eval_noise_batch_size, self.opt.z_size)
     self.eval_tensor = noise.view(-1, self.opt.z_size, 1, 1).to(self.device)

     self.criterion = nn.BCELoss().to(self.device)

     # Filled in later with whichever discriminator/loader the active phase uses.
     self.disc = None
     self.train_loader = None
# Example #2 (score: 0)
 def __init__(self, opt):
     """Configure optimizers, image transforms, data loaders and losses
     for style-transfer training."""
     super(TrainingModel, self).__init__(opt)

     # ImageNet channel statistics used to normalize inputs for the VGG loss net.
     self.vgg_mean = [0.485, 0.456, 0.406]
     self.vgg_std = [0.229, 0.224, 0.225]

     self.main_optimizer = Adam(self.net.main.parameters(), opt.learning_rate_main)

     # 'normal' has one tuning-blocks module; 'dual' splits it in two.
     blocks_lr = opt.learning_rate_blocks
     if opt.network_version == 'normal':
         self.tuning_blocks_optimizer = Adam(
             self.net.tuning_blocks.parameters(), blocks_lr)
     elif opt.network_version == 'dual':
         self.tuning_blocks_lower_optimizer = Adam(
             self.net.tuning_blocks_lower.parameters(), blocks_lr)
         self.tuning_blocks_higher_optimizer = Adam(
             self.net.tuning_blocks_higher.parameters(), blocks_lr)

     # Style images are always VGG-normalized; eval images only when the
     # network itself produces VGG-space output.
     normalize = transforms.Normalize(mean=self.vgg_mean, std=self.vgg_std)
     self.style_transform = transforms.Compose([transforms.ToTensor(), normalize])
     if opt.vgg_output:
         self.eval_transform = transforms.Compose([transforms.ToTensor(), normalize])
     else:
         self.eval_transform = transforms.ToTensor()

     self.vgg = LossNetwork(opt).to(self.device)
     self.train_loader = utils.get_data_loader(opt.train_data_path,
                                               opt.batch_size,
                                               opt.image_size,
                                               train=True,
                                               normalize=opt.vgg_output)
     self.val_loader = utils.get_data_loader(opt.val_data_path,
                                             opt.batch_size,
                                             opt.image_size,
                                             train=False,
                                             normalize=opt.vgg_output)
     self.mse_loss = torch.nn.MSELoss().to(self.device)

     # NOTE(review): 'wight' looks like a typo for 'weight', but the attribute
     # and option names are part of the external interface, so they are kept.
     self.style_wights = [getattr(opt, 'style_wight{}'.format(i)) for i in range(4)]
     self.style_image_path = None
# Example #3 (score: 0)
    # Fragment of a training entry point (the enclosing def is outside this view).
    # Loads hyper-parameters from JSON and exposes them via attribute access.
    # training configuration
    with open("train_config.json", "r") as f:
        train_config = json.load(f)
        args = Namespace(**train_config)

    # initializing networks and optimizers
    # SN_DCGAN additionally takes the number of power iterations used by
    # spectral normalization; plain DCGAN does not.
    if args.type == "DCGAN":
        G, D = utils.get_gan(GANType.DCGAN, device)
        G_optim, D_optim = utils.get_optimizers(G, D)
    elif args.type == "SN_DCGAN":
        G, D = utils.get_gan(GANType.SN_DCGAN, device, args.n_power_iterations)
        G_optim, D_optim = utils.get_optimizers(G, D)
    # NOTE(review): no else branch — an unrecognized args.type leaves G/D
    # unbound and will raise NameError later; confirm whether that is intended.

    # initializing loader for data
    data_loader = utils.get_data_loader(args.batch_size, args.img_size)

    # setting up loss and GT
    # Adversarial BCE loss plus fixed real/fake ground-truth label batches.
    adversarial_loss = nn.BCELoss()
    real_gt, fake_gt = utils.get_gt(args.batch_size, device)

    # for logging
    # A fixed latent batch so generated samples are comparable across epochs.
    log_batch_size = 25
    log_noise = utils.get_latent_batch(log_batch_size, device)
    D_loss_values, G_loss_values = [], []
    img_count = 0

    # responsible for dumping data in TensorBoard
    writer = SummaryWriter(paths.log_path)

    print("training started...")
# Example #4 (score: 0)
    # Fragment of a Config class body (the class header is outside this view).
    eval_step = 1  # epochs

    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()

# init random seed
init_random_seed(params.manual_seed)

# load dataset
# Source- and target-domain loaders share the same root and batch size.
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root,
                                  params.batch_size)
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root,
                                  params.batch_size)

# load dann model
# restore=None: training always starts from fresh weights here.
dann = init_model(net=AlexModel(), restore=None)

# train dann model
print("Start training dann model.")

# NOTE(review): tgt_data_loader is passed twice — presumably the last argument
# should be a separate evaluation loader; confirm against train_dann's signature.
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
                      tgt_data_loader)

print('done')
# Example #5 (score: 0)
    # Fragment of a Config class body (the class header is outside this view).
    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 2e-4


params = Config()

# init random seed
init_random_seed(params.manual_seed)

# load dataset
# Train/eval loader pairs for both the source and the target domain.
src_data_loader = get_data_loader(params.src_dataset,
                                  params.dataset_root,
                                  params.batch_size,
                                  train=True)
src_data_loader_eval = get_data_loader(params.src_dataset,
                                       params.dataset_root,
                                       params.batch_size,
                                       train=False)
tgt_data_loader = get_data_loader(params.tgt_dataset,
                                  params.dataset_root,
                                  params.batch_size,
                                  train=True)
tgt_data_loader_eval = get_data_loader(params.tgt_dataset,
                                       params.dataset_root,
                                       params.batch_size,
                                       train=False)

# load dann model
# NOTE(review): the text below jumps from a top-level comment straight into
# indented Config-style fields — this looks like two separate snippets fused
# together by extraction; the original snippet boundary should be restored.
    log_step = 10  # iter
    eval_step = 1  # epoch
    save_step = 500
    manual_seed = None

params = Config()

if __name__ == '__main__':
    # init random seed
    init_random_seed(params.manual_seed)

    # init device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    # Evaluation loaders use a separate batch size (params.batch_size_eval).
    src_data_loader = get_data_loader(
        params.src_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size, train=True)
    src_data_loader_eval = get_data_loader(
        params.src_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size_eval, train=False)
    tgt_data_loader = get_data_loader(
        params.tgt_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size, train=True)
    tgt_data_loader_eval = get_data_loader(
        params.tgt_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size_eval, train=False)

    # load models
    model = ResModel().to(device)

    # training model
    # Skip source training when a restored, already-trained model is available.
    print("training model")
    if not (model.restored and params.model_trained):
        model = train_src(model, src_data_loader, src_data_loader_eval,
                          tgt_data_loader, tgt_data_loader_eval, device, params)