コード例 #1
0
ファイル: trainer.py プロジェクト: zhaoyongboy/pggan-pytorch
    def __init__(self, config):
        """Set up progressive-GAN training state and networks from ``config``.

        Chooses CUDA vs CPU default tensor types, initialises the
        progressive-growing bookkeeping, builds the generator and
        discriminator, wraps them in ``DataParallel`` on GPU, then prepares
        tensors/dataloader and the optional tensorboard recorder.

        Fixes vs the original:
        - removed a duplicated second assignment of ``trns_tick``/``stab_tick``;
        - ``.cuda(device_id=0)`` replaced with ``.cuda(device=0)`` (the kwarg
          used by every sibling trainer and by current PyTorch);
        - gpus list built with ``list(range(...))`` instead of an append loop.
        """
        self.config = config
        if torch.cuda.is_available():
            self.use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.use_cuda = False
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resl = 2  # we start from 2^2 = 4
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False

        # network and criterion
        self.G = net.Generator(config)
        self.D = net.Discriminator(config)
        print('Generator structure: ')
        print(self.G.model)
        print('Discriminator structure: ')
        print(self.D.model)
        self.mse = torch.nn.MSELoss()
        if self.use_cuda:
            self.mse = self.mse.cuda()
            torch.cuda.manual_seed(config.random_seed)
            # DataParallel is used even for a single GPU: calling .cuda()
            # directly on the bare model did not work reliably here
            # (old PyTorch issue with custom modules).
            if config.n_gpu == 1:
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                gpus = list(range(config.n_gpu))
                self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        # define tensors, and get dataloader.
        self.renew_everything()

        # tensorboard
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()
コード例 #2
0
    def __init__(self, config):
        """Build trainer state, the data loader, and the G/D networks."""
        self.config = config

        # Default tensor type follows GPU availability.
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')

        # Hyper-parameters copied from the config object.
        self.nz = config.nz
        self.optimizer = config.optimizer
        self.resl = 2  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK

        # Progress counters.
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0

        # Fade-in / flush bookkeeping for the growing layers.
        self.fadein = dict(gen=None, dis=None)
        self.complete = dict(gen=0, dis=0)
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift

        self.loader = DL.dataloader(config)
        self.LAMBDA = 2

        # network
        self.G = network.Generator(config)
        self.D = network.Discriminator(config)
        print('Generator structure: ')
        print(self.G.model)
        print('Discriminator structure: ')
        print(self.D.model)

        if self.use_cuda:
            torch.cuda.manual_seed(config.random_seed)
            # Networks are always wrapped in DataParallel before .cuda().
            if config.n_gpu == 1:
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                device_ids = list(range(config.n_gpu))
                self.G = torch.nn.DataParallel(self.G, device_ids=device_ids).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=device_ids).cuda()

        self.renew_everything()

        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()
コード例 #3
0
ファイル: test.py プロジェクト: lihungte96/pggan-pytorch
    def __init__(self, config, loader, path):
        """Prepare a test-time discriminator, optionally restoring a snapshot."""
        self.config = config

        # Default tensor type follows GPU availability.
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz

        # Resolution / schedule parameters.
        self.resl = config.resl  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK

        # Counters and phase flags.
        self.globalIter = 0
        self.globalTick = config.globalTick
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = dict(gen=None, dis=None)
        self.complete = dict(gen=0, dis=0)
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift
        self.load = config.load
        self.batch = 1

        # Only the discriminator is needed here.
        self.D = net.Discriminator(config)

        if self.use_cuda:
            if config.n_gpu == 1:
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                device_ids = list(range(config.n_gpu))
                self.D = torch.nn.DataParallel(self.D, device_ids=device_ids).cuda()

        # Restore weights first, then (re)allocate tensors and optimizers.
        if self.load:
            self.load_snapshot('repo/model')
        self.renew_everything()

        if self.use_cuda:
            if config.n_gpu == 1:
                self.D.cuda(device=0)

        self.loader = loader
        self.path = path
コード例 #4
0
ファイル: trainer.py プロジェクト: celdeldel/chaise_1
    def __init__(self, config):
        """Set up trainer state and the G/D networks.

        When ``config.start_res`` is above the base resolution (2), the
        networks are rebuilt and restored from checkpoints via
        ``g_d_interpolated.recup_nets`` instead of being created fresh.
        """
        self.config = config
        if torch.cuda.is_available():
            self.use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.use_cuda = False
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resl = config.start_res  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        # Progress counters and phase bookkeeping.
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift

        # network and criterion
        if config.start_res == 2:
            self.G = net.Generator(config)
            self.D = net.Discriminator(config)
        else:
            # Resuming above base resolution: rebuild grown networks and
            # load their checkpointed weights.
            self.G, self.D = g_d_interpolated.recup_nets(config)
        self.mse = torch.nn.MSELoss()
        if self.use_cuda:
            self.mse = self.mse.cuda()
            torch.cuda.manual_seed(config.random_seed)
            # NOTE(review): the DataParallel wrapping below is disabled via a
            # bare string literal (a no-op); presumably renew_everything()
            # ships the models to CUDA — confirm before re-enabling.
            """
            if config.n_gpu==1:
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                gpus = []
                for i  in range(config.n_gpu):
                    gpus.append(i)
                self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()  
            """

        # define tensors, ship model to cuda, and get dataloader.
        self.renew_everything()
コード例 #5
0
ファイル: generate.py プロジェクト: juWuBabaaaa/gan
    def __init__(self, sample_dimension, noise_dimension, bs):
        """Create the GAN pair, their RMSprop optimizers, and the criterion.

        Args:
            sample_dimension: dimensionality of real/generated samples.
            noise_dimension: dimensionality of the generator's noise input.
            bs: batch size.
        """
        # Remember construction parameters.
        self.sample_dimension = sample_dimension
        self.noise_dimension = noise_dimension
        self.bs = bs

        # Generator and discriminator networks.
        self.G = network.Generator(noise_dimension=noise_dimension,
                                   sample_dimension=sample_dimension)
        self.D = network.Discriminator(sample_dimension=sample_dimension)

        # One RMSprop optimizer per network (default hyper-parameters).
        self.optimizer_g = torch.optim.RMSprop(self.G.parameters())
        self.optimizer_d = torch.optim.RMSprop(self.D.parameters())

        # Binary cross-entropy loss for the adversarial objective.
        self.criterion = nn.BCELoss()
コード例 #6
0
ファイル: trainer.py プロジェクト: anhvth/pggan-pytorch
    def __init__(self, config):
        """Initialise progressive-GAN training state and wrapped networks."""
        self.config = config

        # Default tensor type tracks GPU availability.
        self.use_cuda = torch.cuda.is_available()
        tensor_type = ('torch.cuda.FloatTensor' if self.use_cuda
                       else 'torch.FloatTensor')
        torch.set_default_tensor_type(tensor_type)

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resolution = 2  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resolution = config.max_resolution
        self.transition_tick = config.transition_tick
        self.stablize_tick = config.stablize_tick
        self.TICK = config.TICK

        # Progress counters all start at zero.
        self.globalIter = self.globalTick = 0
        self.kimgs = self.stack = self.epoch = 0

        # Fade-in / flush bookkeeping.
        self.fadein = dict(gen=None, dis=None)
        self.complete = dict(gen=0, dis=0)
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift

        # Networks are wrapped in DataParallel up front.
        self.G = nn.DataParallel(net.Generator(config))
        self.D = nn.DataParallel(net.Discriminator(config))
        for label, model in (('Generator structure: ', self.G.module.model),
                             ('Discriminator structure: ', self.D.module.model)):
            print(label)
            print(model)
        self.mse = torch.nn.MSELoss()
        self.renew_everything()

        # Optional tensorboard recorder.
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()

        # Resume from pre-trained weights when requested.
        if config.pretrained is not None:
            self.load_pretrained(config.pretrained)
コード例 #7
0
 def __init__(self, opt):
     """Build the training networks and criteria from ``opt``.

     In training mode this creates an encoder, a decoder (used as the
     generator ``netG``) and a discriminator, applies the shared weight
     initialiser to each, and sets up the GAN / KL / reconstruction
     criteria. Outside training mode nothing is constructed here.
     """
     super(TrainModel, self).__init__()
     self.isTrain = opt.isTrain
     if self.isTrain:
         # Encoder and decoder share channel/downsampling settings.
         self.netE = network.Encoder(opt.input_nc, opt.ngf,
                                     opt.n_downsampling)
         self.netE.apply(network.weights_init)
         self.netG = network.Decoder(opt.input_nc, opt.output_nc, opt.ngf,
                                     opt.n_downsampling)
         self.netG.apply(network.weights_init)
         self.netD = network.Discriminator(opt.input_nc, opt.ngf,
                                           opt.n_layer)
         self.netD.apply(network.weights_init)
         # Adversarial BCE plus KL and reconstruction loss functions
         # (the latter two are plain callables, not modules).
         self.criterionGAN = nn.BCELoss()
         self.criterionKL = network.KLLoss
         self.criterionRecon = network.ReconLoss
     else:
         # NOTE(review): inference path intentionally builds nothing here;
         # presumably networks are loaded elsewhere — confirm against callers.
         pass
コード例 #8
0
def recup_nets(config):
    """Rebuild G and D at ``config.start_res`` and load their checkpoints.

    Grows both networks from the base resolution up to ``config.start_res``
    (flushing the fade-in layers at each step), then loads the generator and
    discriminator state dicts named in the config from ``repo/model``.

    Fixes vs the original:
    - ``use_cuda`` was hard-coded ``True``, making the CPU branch dead code
      and crashing on CPU-only machines; it now follows availability.
    - leftover ``print(type(...))`` debug lines removed.
    - ``strict`` is passed by keyword to ``load_state_dict`` for clarity.

    Returns:
        (model_g, model_d): the restored generator and discriminator
        (wrapped in ``DataParallel`` when CUDA is available).
    """
    use_cuda = torch.cuda.is_available()
    checkpoint_path_g = config.checkpoint_generator
    checkpoint_path_d = config.checkpoint_discriminator

    # load trained model.
    model_g = net.Generator(config)
    model_d = net.Discriminator(config)
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        model_g = torch.nn.DataParallel(model_g).cuda(device=0)
        model_d = torch.nn.DataParallel(model_d).cuda(device=0)
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Reach the underlying modules whether or not DataParallel wraps them.
    core_g = model_g.module if use_cuda else model_g
    core_d = model_d.module if use_cuda else model_d

    for resl in range(3, config.start_res + 1):
        core_g.grow_network(resl)
        core_d.grow_network(resl)
        core_g.flush_network()
        core_d.flush_network()
    print('generator :')
    print(model_g)
    print('discriminator :')
    print(model_d)

    print('load generator from checkpoint  ... {}'.format(checkpoint_path_g))
    print(
        'load discriminator from checkpoint ... {}'.format(checkpoint_path_d))
    checkpoint_g = torch.load(os.path.join('repo/model', checkpoint_path_g))
    checkpoint_d = torch.load(os.path.join('repo/model', checkpoint_path_d))
    # strict=False: grown networks may not match the checkpoint exactly.
    core_g.load_state_dict(checkpoint_g['state_dict'], strict=False)
    core_d.load_state_dict(checkpoint_d['state_dict'], strict=False)

    return model_g, model_d
コード例 #9
0
def extract_feature(path):
    """Run the discriminator over the first ``args.num`` images and stack outputs.

    NOTE(review): ``path`` is unused — image locations come entirely from
    ``parser()`` arguments; confirm whether callers expect it to be honoured.

    Fixes vs the original:
    - image paths were joined with ``args.img`` twice (entries of the path
      list already contained the directory), yielding ``img/img/name``;
    - the bare ``except:`` used to bootstrap the output array is replaced by
      an explicit ``None`` sentinel, so real errors are no longer swallowed;
    - an empty listing no longer raises ``IndexError``.

    Returns:
        np.ndarray of discriminator outputs — 1-D for a single image,
        2-D (one row per image) otherwise; ``None`` when ``args.num`` is 0.
    """
    args = parser()

    dis = network.Discriminator(args.depth)
    print('Loading Discirminator Model from ' + args.dis)
    serializers.load_npz(args.dis, dis)

    img_list = sorted(os.listdir(args.img))
    if img_list and img_list[0] == '.DS_Store':
        del img_list[0]

    # Each entry is already a full path; do NOT join with args.img again.
    img_paths = [os.path.join(args.img, name) for name in img_list]
    output = None
    for i in tqdm(range(args.num)):
        im = get_example(img_paths[i], args.depth, None)
        im = im[np.newaxis, :, :, :]
        out_dis, _ = dis(im, alpha=1.0)
        out_dis = np.ravel(out_dis.data)
        if output is None:
            output = out_dis  # first result keeps its 1-D shape, as before
        else:
            output = np.vstack((output, out_dis))

    return output
コード例 #10
0
ファイル: train.py プロジェクト: runzhang1997/master-thesis
    else:
        train10, train20, label = iterator.get_next()

    # set up learning rate.
    '''
    global_step_generator = tf.Variable(0, trainable=False)
    global_step_discriminator = tf.Variable(0, trainable=False)
    lr_discriminator = tf.train.exponential_decay(args.discrimator_learning_rate, 
        global_step_discriminator, 8e4, 0.5, staircase=False)
    lr_generator = tf.train.exponential_decay(args.generator_learning_rate, 
        global_step_generator, 4e4, 0.5, staircase=False)  
    '''
    # Set up models
    discriminator = network.Discriminator(loss_type=args.gan_type,
                                          image_size=args.patch_size,
                                          batch_size=args.batch_size,
                                          norm=args.norm,
                                          run_60=args.run_60)
    generator = network.Generator(adversarial_loss=args.gan_type,
                                  content_loss=args.contentloss,
                                  batch_size=args.batch_size,
                                  discriminator=discriminator,
                                  norm=False,
                                  adv_weight=args.adv_weight,
                                  relu=args.relu,
                                  run_60=args.run_60)

    # Generator
    if args.run_60:
        g_y_pred = generator.forward(train10, train20, train60)
    else:
コード例 #11
0
ファイル: trainer.py プロジェクト: deeptechlabs/pggan-pytorch
    def __init__(self, config):
        """Initialise trainer state, caption support, and the GAN networks.

        Fixes vs the original:
        - ``self.lambda = config.lambda`` was a SyntaxError (``lambda`` is a
          reserved keyword); the value is now read with ``getattr`` and
          stored as ``self.lambda_``.
        - gpus list built with ``list(range(...))``.
        """
        self.config = config
        if torch.cuda.is_available():
            self.use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.use_cuda = False
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resl = 2           # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift
        self.use_captions = config.use_captions
        self.gan_type = config.gan_type
        # 'lambda' is a Python keyword; an attribute of that name (e.g. from
        # an argparse --lambda flag) is only reachable via getattr.
        self.lambda_ = getattr(config, 'lambda')

        if self.use_captions:
            self.ncap = config.ncap

        # network and criterion
        self.G = net.Generator(config, use_captions=self.use_captions)
        self.D = net.Discriminator(config, use_captions=self.use_captions)
        print('Generator structure: ')
        print(self.G.model)
        print('Discriminator structure: ')
        print(self.D.model)

        # LSGAN uses an MSE criterion; other gan_type values define none here.
        if self.gan_type == 'lsgan':
            self.mse = torch.nn.MSELoss()

        if self.use_cuda:
            if self.gan_type == 'lsgan':
                self.mse = self.mse.cuda()
            torch.cuda.manual_seed(config.random_seed)

            if config.n_gpu == 1:
                # if only on single GPU
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                # if more we're doing multiple GPUs
                gpus = list(range(config.n_gpu))
                self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        # define tensors, ship model to cuda, and get dataloader.
        self.renew_everything()

        # tensorboard
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()
コード例 #12
0
        # construct DataLoader list
        if cuda:
            torch_dataset = torch.utils.data.TensorDataset(
                torch.FloatTensor(gene_exp).cuda(), torch.LongTensor(labels).cuda())
        else:
            torch_dataset = torch.utils.data.TensorDataset(
                torch.FloatTensor(gene_exp), torch.LongTensor(labels))
        data_loader = torch.utils.data.DataLoader(torch_dataset, batch_size=batch_size,
                                                        shuffle=True, drop_last=True)
        batch_loader_dict[i+1] = data_loader

    # create model
    encoder = models.Encoder(num_inputs=num_inputs)
    decoder_a = models.Decoder_a(num_inputs=num_inputs)
    decoder_b = models.Decoder_b(num_inputs=num_inputs)
    discriminator = models.Discriminator(num_inputs=num_inputs)

    if cuda:
        encoder.cuda()
        decoder_a.cuda()
        decoder_b.cuda()
        discriminator.cuda()

    # training
    loss_total_list = []  # list of total loss
    loss_reconstruct_list = []
    loss_transfer_list = []
    loss_classifier_list = []
    for epoch in range(1, num_epochs + 1):
        log_interval = config.log_interval
        base_lr = config.base_lr
コード例 #13
0
print(opt)

# (channels, height, width) of the images the GAN operates on.
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

# Loss weight for gradient penalty
lambda_gp = 10

# Initialize generator and discriminator
generator = network.Generator(image_size=opt.img_size,
                              z_dim=opt.latent_dim,
                              conv_dim=opt.conv_dim,
                              selfattn=True)
discriminator = network.Discriminator(image_size=opt.img_size,
                                      conv_dim=opt.conv_dim,
                                      selfattn=True)

# Move both networks to the GPU when available.
if cuda:
    generator.cuda()
    discriminator.cuda()

# Configure data loader
dataloader = DataLoader(datasets.CIFAR10(
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]),
),
コード例 #14
0
ファイル: representation_train.py プロジェクト: eudtr/EUDTR
    def __init__(self,
                 timesteps=1,
                 nb_random_samples=10,
                 batch_size=1,
                 epochs=30,
                 lr=1e-3,
                 in_channels=1,
                 channels=10,
                 depth=1,
                 kernel_size=3,
                 reduced_size=10,
                 out_channels=10,
                 negative_penalty=1,
                 n_clusters=1,
                 n_init=10, **kwargs):

        """
        EUDTR: Enhanced Unsupervised Deep Temporal Representation learning model

        Args:
            timesteps: Length of time series.
            nb_random_samples: Number of randomly chosen intervals to select the
                               positive and negative sample in the loss.
            batch_size: Batch size used during the training of the network.
            epochs: Number of optimization steps to perform for the training of
                    the network.
            lr: Learning rate of the Adam optimizer used to train the network.
            in_channels: Number of input channels.
            channels: Number of channels manipulated in the causal CNN.
            depth: Depth of the causal CNN.
            kernel_size: Kernel size of the applied non-residual convolutions.
            reduced_size: Fixed length to which the output time series of the
                          causal CNN is reduced.
            out_channels: Number of features in the final output.
            negative_penalty: Multiplicative coefficient for the negative sample
                              loss.
            n_clusters: Number of clusters.
            n_init: Number of times the k-means algorithm will be run with
                    different centroid seeds.
        """
        # NOTE(review): extra **kwargs are accepted but not used anywhere
        # in this constructor — confirm they are consumed elsewhere.

        self.encoder = network.CausalCNNEncoder(in_channels, channels, depth, reduced_size, out_channels, kernel_size)
        self.decoder = network.Discriminator(timesteps, out_channels, in_channels, reduced_size, kernel_size, stride=1, padding=1, out_padding=0)

        # Both networks run in double precision on the selected device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.encoder = self.encoder.double().to(self.device)
        self.decoder = self.decoder.double().to(self.device)

        self.mi_loss = reconstruction_loss.MILoss(self.device).to(self.device)

        self.batch_size = batch_size
        self.epochs = epochs

        # Separate Adam optimizer for each network, sharing one learning rate.
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=lr)

        self.negative_penalty = negative_penalty
        self.out_channels = out_channels

        self.n_clusters = n_clusters
        self.n_init = n_init

        self.nb_random_samples = nb_random_samples
コード例 #15
0
    def __init__(self, config):
        """Set up training, optionally restoring G/Gs/D from checkpoints.

        ``Gs`` is a second generator instance (kept smoothed alongside ``G``).
        When checkpoint paths are configured, the network structures are
        grown to the checkpointed resolution before weights are loaded.

        Fixes vs the original:
        - ``xrange`` (Python 2 only, NameError on Python 3) replaced with
          ``range``;
        - in the multi-GPU branch ``self.Gs`` was never wrapped in
          ``DataParallel`` although ``self.Gs.module`` is accessed below;
          it is now wrapped like ``G`` and ``D``.
        """
        self.config = config
        if torch.cuda.is_available():
            self.use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            tqdm.write('Using GPU.')
        else:
            self.use_cuda = False
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.minibatch_repeat = 4
        self.resl = 2         # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift
        self.flag_wgan = self.config.flag_wgan

        # network and criterion
        self.G = net.Generator(config)
        self.Gs = net.Generator(config)
        self.D = net.Discriminator(config)
        self.mse = torch.nn.MSELoss()
        if self.use_cuda:
            self.mse = self.mse.cuda()
            torch.cuda.manual_seed(int(time.time()))
            if config.n_gpu == 1:
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.Gs = torch.nn.DataParallel(self.Gs).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                gpus = list(range(config.n_gpu))
                self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                # BUG FIX: Gs was previously left unwrapped here, breaking
                # the self.Gs.module accesses below on multi-GPU setups.
                self.Gs = torch.nn.DataParallel(self.Gs, device_ids=gpus).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        self.gen_ckpt = config.gen_ckpt
        self.gs_ckpt = config.gs_ckpt
        self.dis_ckpt = config.dis_ckpt
        if self.gen_ckpt != '' and self.dis_ckpt != '':
            # Recover resolution/tick by parsing the checkpoint file name.
            pattern = '{}gen_R{}_T{}.pth.tar'
            parsed = parse(pattern, self.gen_ckpt)
            restore_resl = int(parsed[1])
            restore_tick = int(parsed[2])
            # Restore the network structure.
            for resl in range(3, restore_resl + 1):
                self.G.module.grow_network(resl)
                self.Gs.module.grow_network(resl)
                self.D.module.grow_network(resl)
                if resl < restore_resl:
                    self.G.module.flush_network()
                    self.Gs.module.flush_network()
                    self.D.module.flush_network()

            # for _ in range(int(self.resl), restore_resl):
            #     self.lr = self.lr * float(self.config.lr_decay)
            print(
                "Restored resolution", restore_resl,
                "Restored global tick", restore_tick,
                "Restored learning rate", self.lr)
            self.resl = restore_resl
            self.globalTick = restore_tick
            # Restore the network setting.
            if self.resl != 2:
                self.phase = 'stab'

        # define tensors, ship model to cuda, and get dataloader.
        self.renew_everything()
        if self.gen_ckpt != '' and self.dis_ckpt != '':
            self.G.module.flush_network()
            self.Gs.module.flush_network()
            self.D.module.flush_network()

            self.globalIter = floor(self.globalTick * self.TICK / (self.loader.batchsize * self.minibatch_repeat))
            gen_ckpt = torch.load(self.gen_ckpt)
            gs_ckpt = torch.load(self.gs_ckpt)
            dis_ckpt = torch.load(self.dis_ckpt)
            self.opt_d.load_state_dict(dis_ckpt['optimizer'])
            self.opt_g.load_state_dict(gen_ckpt['optimizer'])
            print('Optimizer restored.')
            self.resl = gen_ckpt['resl']
            self.G.module.load_state_dict(gen_ckpt['state_dict'])
            self.Gs.module.load_state_dict(gs_ckpt['state_dict'])
            self.D.module.load_state_dict(dis_ckpt['state_dict'])
            print('Model weights restored.')

            # Drop references so the loaded checkpoint tensors can be freed.
            gen_ckpt = None
            dis_ckpt = None
            gs_ckpt = None

        print('Generator structure: ')
        print(self.G)
        print('Discriminator structure: ')
        print(self.D)

        # tensorboard
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()
コード例 #16
0
ファイル: trainer.py プロジェクト: taflahi/pytorch_enhance
def train(enhancer, mode, param):
    """Adversarially train the enhancer's generator against a discriminator.

    Args:
        enhancer: full network exposing ``generator`` / ``discriminator``
            sub-modules plus ``clone_discriminator_to`` /
            ``assign_back_discriminator`` helpers.
        mode: ``'pretrain'`` uses the shallow c22 perceptual features;
            any other value uses the deeper c52 features.
        param: hyper-parameter dict ('epochs', 'epoch-size', 'batch-size',
            'zoom', 'image-size', 'learning-rate', 'generator-start',
            'discriminator-start', 'adversarial-start', 'adversary-weight',
            'perceptual-weight', 'smoothness-weight', ...).

    Side effects: saves weights to ``model/model_<mode>.pth`` every 10
    epochs and once more at the end; KeyboardInterrupt stops training
    silently (only previously checkpointed weights survive).
    """
    # create initial discriminator (a detached clone trained separately)
    disc = network.Discriminator(param['discriminator-size'])
    assert disc.channels == enhancer.discriminator.channels

    seed_size = param['image-size'] // param['zoom']
    images = np.zeros(
        (param['batch-size'], 3, param['image-size'], param['image-size']),
        dtype=np.float32)
    seeds = np.zeros((param['batch-size'], 3, seed_size, seed_size),
                     dtype=np.float32)

    loader.copy(images, seeds)
    # initial lr: generator yielding a rate halved every 75 epochs
    lr = network.decay_learning_rate(param['learning-rate'], 75, 0.5)

    # optimizers start frozen (lr=0) until their start epoch is reached
    opt_gen = optim.Adam(enhancer.generator.parameters(), lr=0)
    opt_disc = optim.Adam(disc.parameters(), lr=0)

    try:
        average, start = None, time.time()
        for epoch in range(param['epochs']):
            # BUG FIX: adversary_weight used to be reset to 5e2 at the top
            # of every epoch, so the 'adversary-weight' assigned at the end
            # of epoch (adversarial-start - 1) was immediately clobbered and
            # never actually used.  Derive it from the epoch index instead.
            if epoch >= param['adversarial-start']:
                adversary_weight = param['adversary-weight']
            else:
                adversary_weight = 5e2

            total, stats = None, None

            l_r = next(lr)
            if epoch >= param['generator-start']:
                network.update_optimizer_lr(opt_gen, l_r)
            if epoch >= param['discriminator-start']:
                network.update_optimizer_lr(opt_disc, l_r)

            for step in range(param['epoch-size']):
                enhancer.zero_grad()
                disc.zero_grad()

                loader.copy(images, seeds)

                # run full network once
                gen_out, c12, c22, c32, c52, disc_out = enhancer(images, seeds)

                # clone discriminator on the full network
                enhancer.clone_discriminator_to(disc)

                # output of the freshly cloned network (should match disc_out)
                disc_out2 = disc(c12.detach(), c22.detach(), c32.detach())
                disc_out_numpy = disc_out2.data.cpu().numpy(
                ) if torch.cuda.is_available() else disc_out2.data.numpy()

                disc_out_mean = np.mean(disc_out_numpy, axis=(1, 2, 3))

                stats = stats + disc_out_mean if stats is not None else disc_out_mean

                # generator loss: perceptual + total-variation + adversarial
                if mode == 'pretrain':
                    gen_loss = network.loss_perceptual(c22[:param['batch-size']], c22[param['batch-size']:]) * param['perceptual-weight'] \
                        + network.loss_total_variation(gen_out) * param['smoothness-weight'] \
                        + network.loss_adversarial(disc_out[param['batch-size']:]) * adversary_weight
                else:
                    gen_loss = network.loss_perceptual(c52[:param['batch-size']], c52[param['batch-size']:]) * param['perceptual-weight'] \
                        + network.loss_total_variation(gen_out) * param['smoothness-weight'] \
                        + network.loss_adversarial(disc_out[param['batch-size']:]) * adversary_weight

                # discriminator loss: first half of the batch vs. second half
                disc_loss = network.loss_discriminator(
                    disc_out2[:param['batch-size']],
                    disc_out2[param['batch-size']:])

                gen_loss_data = gen_loss.data.cpu().numpy(
                ) if torch.cuda.is_available() else gen_loss.data.numpy()

                total = total + gen_loss_data if total is not None else gen_loss_data

                # exponential moving average of the generator loss
                average = gen_loss_data if average is None else average * \
                    0.95 + 0.05 * gen_loss_data
                print('↑' if gen_loss_data > average else '↓',
                      end='',
                      flush=True)

                # update parameters step

                gen_loss.backward()
                disc_loss.backward()

                # clip_grad_norm was renamed clip_grad_norm_; the old alias
                # is deprecated and removed in modern PyTorch
                torch.nn.utils.clip_grad_norm_(disc.parameters(), 5)

                opt_gen.step()
                opt_disc.step()

                # rebuild real discriminator from clone
                enhancer.assign_back_discriminator(disc)

            total /= param['epoch-size']
            stats /= param['epoch-size']

            print('\nOn Epoch: ' + str(epoch))

            print('Generator Loss: ')
            print(total)

            real, fake = stats[:param['batch-size']], stats[
                param['batch-size']:]
            print('  - discriminator', real.mean(),
                  len(np.where(real > 0.5)[0]), fake.mean(),
                  len(np.where(fake < -0.5)[0]))

            if epoch == param['adversarial-start'] - 1:
                print('  - generator now optimizing against discriminator.')

            # Then save every several epochs
            if epoch % 10 == 0:
                torch.save(enhancer.state_dict(),
                           'model/model_' + mode + '.pth')

        torch.save(enhancer.state_dict(), 'model/model_' + mode + '.pth')
        print("Training ends after: " + str(float(time.time() - start)) +
              " seconds")
    except KeyboardInterrupt:
        pass
コード例 #17
0
ファイル: trainer.py プロジェクト: mikanCan/PG-GAN
    def __init__(self, config):
        """Set up run directory, device defaults, training state, and nets.

        Args:
            config: parsed options (nz, lr, max_resl, tick counts, gpu
                count, flags, ...); dumped verbatim into ``conf.txt``.
        """
        self.config = config

        # tensorboard
        self.use_tb = config.use_tb
        self.tb = None

        # dir setup: pick the first unused log/act_NNNNN/ run directory,
        # create it, and write the full config into conf.txt inside it
        self.save_dir = "log/"
        if len(os.listdir(self.save_dir)) == 0:
            os.mkdir(self.save_dir + 'act_{:05}'.format(0))
        for i in range(1, 1000):
            if "act_{:05}".format(i) in os.listdir(self.save_dir):
                i += 1  # NOTE(review): dead store — the for loop rebinds i
                continue
            self.save_dir += "act_{:05}/".format(i)
            os.mkdir(self.save_dir)
            conf = self.save_dir + "conf.txt"
            with open(conf, mode='w') as f:
                for k, v in vars(config).items():
                    f.writelines('{} : {} \n'.format(k, v))

            if self.use_tb:
                self.tb = tensorboard.tf_recorder(config)
            break
        # per-run output subdirectories for samples and checkpoints
        if not os.path.exists(self.save_dir + 'images'):
            os.mkdir(self.save_dir + 'images')
        if not os.path.exists(self.save_dir + 'model'):
            os.mkdir(self.save_dir + 'model')
        print(self.save_dir)

        # the default tensor type doubles as the device selector for all
        # tensors created later
        if torch.cuda.is_available():
            print("cuda")
            self.use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.use_cuda = False
            torch.set_default_tensor_type('torch.FloatTensor')

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resl = 2  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        # progress counters: iterations, ticks, thousands of images, epochs
        self.globalIter = 0
        self.globalTick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        # per-network fade-in layer handles and completion percentages
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift

        # network and criterion
        self.G = net.Generator(config)
        self.D = net.Discriminator(config)
        print('Generator structure: ')
        print(self.G.model)
        print('Discriminator structure: ')
        print(self.D.model)
        self.mse = torch.nn.MSELoss()
        self.l1 = torch.nn.L1Loss()
        if self.use_cuda:
            self.mse = self.mse.cuda()
            self.l1 = self.l1.cuda()
            torch.cuda.manual_seed(config.random_seed)
            # models are wrapped in DataParallel even for a single GPU —
            # presumably to work around plain .cuda() misbehaving on these
            # modules (see upstream pggan-pytorch); TODO confirm
            if config.n_gpu == 1:
                self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                self.D = torch.nn.DataParallel(self.D).cuda(device=0)
            else:
                gpus = []
                for i in range(config.n_gpu):
                    gpus.append(i)
                self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        # define tensors, ship model to cuda, and get dataloader.
        self.renew_everything()
コード例 #18
0
    transform=transforms.Compose([
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
                                         batch_size=batch_size,
                                         shuffle=True)

# GAN Loss function (LSGAN-style: MSE on the discriminator output)
adversarial_loss = nn.MSELoss()

# Initialize generator and discriminator (class-conditional GAN)
# NOTE(review): latent_dim, n_classes, n_channels, cuda, lr, b1, b2 are
# defined earlier in the file, outside this excerpt.
generator = network.Generator(latent_dim=latent_dim,
                              classes=n_classes,
                              channels=n_channels)
discriminator = network.Discriminator(classes=n_classes, channels=n_channels)

# Label embedding: one learnable vector per class, of dimension n_classes
label_emb = nn.Embedding(n_classes, n_classes)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    label_emb.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=(b1, b2))
コード例 #19
0
ファイル: run.py プロジェクト: yvonwin/talking-heads
def meta_train(device, dataset_path, continue_id):
    """Meta-train Embedder/Generator/Discriminator on VoxCeleb-style data.

    Args:
        device: CUDA device id, or None/'cpu' to run on CPU.
        dataset_path: root directory containing '.vid' samples.
        continue_id: run id to resume from via load_model(), or None.

    Side effects: writes sample images to config.GENERATED_DIR and saves
    model checkpoints periodically and at the end of every epoch.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if device is not None and device != 'cpu':
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(device)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        logging.info(f'Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    dataset = VoxCelebDataset(root=dataset_path,
                              extension='.vid',
                              shuffle=False,
                              shuffle_frames=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.IMAGE_SIZE),
                                  transforms.CenterCrop(config.IMAGE_SIZE),
                                  transforms.ToTensor(),
                                  # ImageNet normalization statistics
                                  transforms.Normalize([0.485, 0.456, 0.406],
                                                       [0.229, 0.224, 0.225]),
                              ]))

    # NETWORK ----------------------------------------------------------------------------------------------------------

    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    # 143000 sizes D's identity projection matrix W — presumably the number
    # of training identities; TODO confirm against the dataset
    D = network.Discriminator(143000).type(dtype)

    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    # E and G share one optimizer (joint loss); D has its own
    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(device, feed_forward=True)
    criterion_D = network.LossD(device)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(
        f'Starting training loop. Epochs: {config.EPOCHS} Dataset Size: {len(dataset)}'
    )

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()

            # Put one frame aside (frame t)
            t = video.pop()

            # Calculate average encoding vector for video
            e_vectors = []
            for s in video:
                x_s = s['frame'].type(dtype)
                y_s = s['landmarks'].type(dtype)
                e_vectors.append(E(x_s, y_s))
            e_hat = torch.stack(e_vectors).mean(dim=0)

            # Generate frame using landmarks from frame t
            x_t = t['frame'].type(dtype)
            y_t = t['landmarks'].type(dtype)
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i],
                                     D_act, D_act_hat)
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            # NOTE(review): retain_graph=True keeps the graph alive past this
            # backward — presumably for the second D pass below; verify it is
            # still needed since that pass reruns G
            loss.backward(retain_graph=True)

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            r_x_hat, D_act_hat = D(G(y_t, e_hat), y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_durations.append(batch_end - batch_start)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0 or batch_num == 0:
                avg_time = sum(batch_durations,
                               timedelta(0)) / len(batch_durations)
                logging.info(
                    f'Epoch {epoch+1}: [{batch_num + 1}/{len(dataset)}] | '
                    f'Avg Time: {avg_time} | '
                    f'Loss_E_G = {loss_E_G.item():.4} Loss_D {loss_D.item():.4}'
                )
                logging.debug(
                    f'D(x) = {r_x.item():.4} D(x_hat) = {r_x_hat.item():.4}')

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                if not os.path.isdir(config.GENERATED_DIR):
                    os.makedirs(config.GENERATED_DIR)

                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M}_x.png'), x_t)
                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M}_x_hat.png'),
                    x_hat)

            if (batch_num + 1) % 2000 == 0:
                save_model(E, device)
                save_model(G, device)
                save_model(D, device)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, device, run_start)
        save_model(G, device, run_start)
        save_model(D, device, run_start)
        epoch_end = datetime.now()
        logging.info(
            f'Epoch {epoch+1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {sum(batch_durations, timedelta(0)) / len(batch_durations)}'
        )
コード例 #20
0
    y_cat = np.zeros((y.shape[0], num_columns))
    y_cat[range(y.shape[0]), y] = 1.

    return Variable(FloatTensor(y_cat))


# GAN Loss functions (InfoGAN setup): LSGAN-style MSE adversarial term,
# cross-entropy for the categorical code, MSE for the continuous code
adversarial_loss = nn.MSELoss()
categorical_loss = nn.CrossEntropyLoss()
continuous_loss = nn.MSELoss()

# Initialize generator and discriminator
# NOTE(review): latent_dim, n_classes, code_dim, cuda, lr, b1, b2 are
# defined earlier in the file, outside this excerpt.
generator = network.Generator(latent_dim=latent_dim,
                              categorical_dim=n_classes,
                              continuous_dim=code_dim)
discriminator = network.Discriminator(categorical_dim=n_classes)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    categorical_loss.cuda()
    continuous_loss.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=(b1, b2))
# itertools.chain concatenates iterables, e.g.
# itertools.chain([1, 2, 3], {'a', 'b', 'c'}) -> (1, 2, 3, 'a', 'b', 'c')
optimizer_info = torch.optim.Adam(itertools.chain(generator.parameters(),
コード例 #21
0
def meta_train(gpu, dataset_path, continue_id):
    """Meta-train Embedder/Generator/Discriminator on a VoxCeleb dataset.

    Args:
        gpu: truthy when running on GPU (passed through to save_model and
            used for logging; module placement comes from the GPU mapping).
        dataset_path: root directory containing '.vid' samples.
        continue_id: run id to resume from via load_model(), or None.

    Side effects: writes sample images to config.GENERATED_DIR and saves
    model checkpoints every 100 batches and at the end of every epoch.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    logging.info(f'Running on {"GPU" if gpu else "CPU"}.')

    # region DATASET----------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        subset_size=config.SUBSET_SIZE,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ])
    )
    dataset = DataLoader(raw_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

    # endregion

    # region NETWORK ---------------------------------------------------------------------------------------------------

    E = network.Embedder(GPU['Embedder'])
    G = network.Generator(GPU['Generator'])
    D = network.Discriminator(len(raw_dataset), GPU['Discriminator'])
    criterion_E_G = network.LossEG(config.FEED_FORWARD, GPU['LossEG'])
    criterion_D = network.LossD(GPU['LossD'])

    # E and G share one optimizer (joint loss); D has its own
    optimizer_E_G = Adam(
        params=list(E.parameters()) + list(G.parameters()),
        lr=config.LEARNING_RATE_E_G
    )
    optimizer_D = Adam(
        params=D.parameters(),
        lr=config.LEARNING_RATE_D
    )

    if continue_id is not None:
        E = load_model(E, continue_id)
        G = load_model(G, continue_id)
        D = load_model(D, continue_id)

    # endregion

    # region TRAINING LOOP ---------------------------------------------------------------------------------------------
    logging.info(f'Epochs: {config.EPOCHS} Batches: {len(dataset)} Batch Size: {config.BATCH_SIZE}')

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):

            # region PROCESS BATCH -------------------------------------------------------------------------------------
            batch_start = datetime.now()

            # video [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video.
            # BUG FIX: the reshape had no receiver ("e_in = .reshape(...)"),
            # a SyntaxError; flatten the batch and frame axes of `video` so
            # the embedder sees one frame-pair per row.
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4], dims[5])  # [BxK, 2, C, W, H]
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # B, K, len(e)
            e_hat = e_vectors.mean(dim=1)

            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, _ = D(x_hat, y_t, i)
            r_x, _ = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat, D.W[:, i].transpose(1, 0))
            loss_D = criterion_D(r_x, r_x_hat)
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again (generator output detached: only D updates)
            x_hat = G(y_t, e_hat).detach()
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat)
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()

            # endregion

            # region SHOW PROGRESS -------------------------------------------------------------------------------------
            if (batch_num + 1) % 1 == 0 or batch_num == 0:
                logging.info(f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                             f'Time: {batch_end - batch_start} | '
                             f'Loss_E_G = {loss_E_G.item():.4f} Loss_D = {loss_D.item():.4f}')
                logging.debug(f'D(x) = {r_x.mean().item():.4f} D(x_hat) = {r_x_hat.mean().item():.4f}')
            # endregion

            # region SAVE ----------------------------------------------------------------------------------------------
            save_image(os.path.join(config.GENERATED_DIR, f'last_result_x.png'), x_t[0])
            save_image(os.path.join(config.GENERATED_DIR, f'last_result_x_hat.png'), x_hat[0])

            if (batch_num + 1) % 100 == 0:
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'), x_t[0])
                save_image(os.path.join(config.GENERATED_DIR, f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'), x_hat[0])

            if (batch_num + 1) % 100 == 0:
                save_model(E, gpu, run_start)
                save_model(G, gpu, run_start)
                save_model(D, gpu, run_start)

            # endregion

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, gpu, run_start)
        save_model(G, gpu, run_start)
        save_model(D, gpu, run_start)
        epoch_end = datetime.now()
        logging.info(f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. ')
コード例 #22
0
def train():
    """CLI entry point: WGAN-GP training on CelebA attributes (Chainer).

    Parses command-line options, optionally restores model/optimizer
    snapshots, builds a chainermn-scattered labeled dataset, runs the
    Trainer with periodic snapshots and sample images, and finally saves
    models and optimizers under ./results/.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=0)
    parser.add_argument('--dir', type=str, default='./CelebA_Datasets/')
    parser.add_argument('--gen', type=str, default=None)
    parser.add_argument('--dis', type=str, default=None)
    parser.add_argument('--optg', type=str, default=None)
    parser.add_argument('--optd', type=str, default=None)
    parser.add_argument('--epoch', '-e', type=int, default=3)
    parser.add_argument('--lr', '-l', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0)
    parser.add_argument('--beta2', type=float, default=0.99)
    parser.add_argument('--batch', '-b', type=int, default=16)
    parser.add_argument('--depth', '-d', type=int, default=0)
    parser.add_argument('--alpha', type=float, default=0)
    parser.add_argument('--delta', type=float, default=0.00005)
    parser.add_argument('--out', '-o', type=str, default='img/')
    parser.add_argument('--num', '-n', type=int, default=10)
    parser.add_argument('--im_num', '-i', type=int, default=200000)
    args = parser.parse_args()

    print('==================================================')
    print('Depth: {}'.format(args.depth))
    print('Num Minibatch Size: {}'.format(args.batch))
    print('Num epoch: {}'.format(args.epoch))
    print('==================================================')

    gen = network.Generator(depth=args.depth)
    if args.gen is not None:
        print('loading generator model from ' + args.gen)
        serializers.load_npz(args.gen, gen)

    dis = network.Discriminator(depth=args.depth)
    if args.dis is not None:
        print('loading discriminator model from ' + args.dis)
        serializers.load_npz(args.dis, dis)

    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Adam with the PGGAN-paper betas (beta1=0, beta2=0.99 by default)
    opt_g = optimizers.Adam(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
    opt_g.setup(gen)
    if args.optg is not None:
        print('loading generator optimizer from ' + args.optg)
        serializers.load_npz(args.optg, opt_g)

    opt_d = optimizers.Adam(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
    opt_d.setup(dis)
    if args.optd is not None:
        print('loading discriminator optimizer from ' + args.optd)
        serializers.load_npz(args.optd, opt_d)

    # labeled dataset, scattered across workers via the module-level
    # chainermn communicator `comm` (defined elsewhere in this file)
    train = datasets.labelandID(args.dir, './../../GAN/text_file/out_attr.txt',
                                args.im_num, args.depth)
    train = chainer.datasets.LabeledImageDataset(pairs=train,
                                                 root='resized_images')

    train = chainermn.scatter_dataset(train, comm)
    train_iter = iterators.SerialIterator(train, batch_size=args.batch)

    updater = WganGpUpdater(alpha=args.alpha,
                            delta=args.delta,
                            models=(gen, dis),
                            iterator={'main': train_iter},
                            optimizer={
                                'gen': opt_g,
                                'dis': opt_d
                            },
                            device=args.gpu)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out='results')

    # start with a clean sample-image directory
    if os.path.isdir(args.out):
        shutil.rmtree(args.out)
    os.makedirs(args.out)

    def output_image(gen, depth, out, num):
        # Trainer extension: dump `num` generated samples each epoch.
        # `depth` is accepted for interface symmetry but unused here.
        @chainer.training.make_extension()
        def make_image(trainer):
            z = gen.z(num)

            # fixed conditioning attribute vector for every sample
            attribute = [1, 0, 0, 0, 1]
            c = [cupy.asarray(attribute, cupy.float32) for i in range(num)]
            c = cupy.asarray(c, dtype=cupy.float32)
            x = gen(z, c, alpha=trainer.updater.alpha)
            x = chainer.cuda.to_cpu(x.data)

            # CONSISTENCY FIX: iterate over the closure parameter `num`
            # instead of reaching for the outer `args.num` (same value at
            # the single call site, but the parameter is the contract)
            for i in range(num):
                img = x[i].copy()
                filename = os.path.join(
                    out, '%d_%d.png' % (trainer.updater.epoch, i))
                utils.save_image(img, filename)

        return make_image

    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'gen_loss', 'loss_d', 'loss_l', 'loss_dr', 'dis_loss',
            'alpha'
        ]))
    trainer.extend(extensions.snapshot_object(gen, 'gen'),
                   trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(dis, 'dis'),
                   trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(opt_g, 'opt_g'),
                   trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(opt_d, 'opt_d'),
                   trigger=(10, 'epoch'))
    trainer.extend(output_image(gen, args.depth, args.out, args.num),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.ProgressBar(update_interval=1))

    trainer.run()

    modelname = './results/gen'
    print('saving generator model to ' + modelname)
    serializers.save_npz(modelname, gen)

    modelname = './results/dis'
    print('saving discriminator model to ' + modelname)
    serializers.save_npz(modelname, dis)

    optname = './results/opt_g'
    print('saving generator optimizer to ' + optname)
    serializers.save_npz(optname, opt_g)

    optname = './results/opt_d'
    # BUG FIX: message said "generator optimizer" while saving opt_d
    print('saving discriminator optimizer to ' + optname)
    serializers.save_npz(optname, opt_d)
コード例 #23
0
ファイル: my_train.py プロジェクト: knok/chainer-PGGAN
def train():
    """CLI entry point: PGGAN-style WGAN-GP training (Chainer).

    Parses command-line options, optionally restores model/optimizer
    snapshots, trains with WganGpUpdater, writes sample images each epoch,
    and saves final models and optimizers under ./results/.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--dir', type=str, default='./train_images/')
    parser.add_argument('--gen', type=str, default=None)
    parser.add_argument('--dis', type=str, default=None)
    parser.add_argument('--optg', type=str, default=None)
    parser.add_argument('--optd', type=str, default=None)
    parser.add_argument('--epoch', '-e', type=int, default=3)
    parser.add_argument('--lr', '-l', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0)
    parser.add_argument('--beta2', type=float, default=0.99)
    parser.add_argument('--batch', '-b', type=int, default=16)
    parser.add_argument('--depth', '-d', type=int, default=0)
    parser.add_argument('--alpha', type=float, default=0)
    parser.add_argument('--delta', type=float, default=0.00005)
    parser.add_argument('--out', '-o', type=str, default='img/')
    parser.add_argument('--num', '-n', type=int, default=10)
    parser.add_argument('--dim', type=int, default=5)
    args = parser.parse_args()

    train = dataset.YuiDataset(directory=args.dir, depth=args.depth)
    train_iter = iterators.SerialIterator(train, batch_size=args.batch)

    gen = network.LimGenerator(depth=args.depth, dim=args.dim)
    if args.gen is not None:
        print('loading generator model from ' + args.gen)
        serializers.load_npz(args.gen, gen)

    dis = network.Discriminator(depth=args.depth)
    if args.dis is not None:
        print('loading discriminator model from ' + args.dis)
        serializers.load_npz(args.dis, dis)

    if args.gpu >= 0:
        # BUG FIX: the device id was hard-coded to 0, silently ignoring the
        # --gpu flag that was just checked above
        cuda.get_device_from_id(args.gpu).use()
        gen.to_gpu()
        dis.to_gpu()

    # Adam with the PGGAN-paper betas (beta1=0, beta2=0.99 by default)
    opt_g = optimizers.Adam(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
    opt_g.setup(gen)
    if args.optg is not None:
        print('loading generator optimizer from ' + args.optg)
        serializers.load_npz(args.optg, opt_g)

    opt_d = optimizers.Adam(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
    opt_d.setup(dis)
    if args.optd is not None:
        print('loading discriminator optimizer from ' + args.optd)
        serializers.load_npz(args.optd, opt_d)

    updater = WganGpUpdater(alpha=args.alpha,
                            delta=args.delta,
                            models=(gen, dis),
                            iterator={'main': train_iter},
                            optimizer={'gen': opt_g, 'dis': opt_d},
                            device=args.gpu)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out='results')

    # start with a clean sample-image directory seeded with real examples
    if os.path.isdir(args.out):
        shutil.rmtree(args.out)
    os.makedirs(args.out)
    for i in range(args.num):
        img = train.get_example(i)
        filename = os.path.join(args.out, 'real_%d.png' % i)
        utils.save_image(img, filename)

    def output_image(gen, depth, out, num):
        # Trainer extension: dump `num` generated samples each epoch.
        # `depth` is accepted for interface symmetry but unused here.
        @chainer.training.make_extension()
        def make_image(trainer):
            z = gen.z(num)
            x = gen(z, alpha=trainer.updater.alpha)
            x = chainer.cuda.to_cpu(x.data)

            # CONSISTENCY FIX: iterate over the closure parameter `num`
            # instead of the outer `args.num` (same value at the single
            # call site, but the parameter is the contract)
            for i in range(num):
                img = x[i].copy()
                filename = os.path.join(out, '%d_%d.png' % (trainer.updater.epoch, i))
                utils.save_image(img, filename)

        return make_image

    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
    trainer.extend(extensions.PrintReport(['epoch', 'gen_loss', 'loss_d', 'loss_l', 'loss_dr', 'dis_loss', 'alpha']))
    trainer.extend(extensions.snapshot_object(gen, 'gen'), trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(dis, 'dis'), trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(opt_g, 'opt_g'), trigger=(10, 'epoch'))
    trainer.extend(extensions.snapshot_object(opt_d, 'opt_d'), trigger=(10, 'epoch'))
    trainer.extend(output_image(gen, args.depth, args.out, args.num), trigger=(1, 'epoch'))
    trainer.extend(extensions.ProgressBar(update_interval=1))

    trainer.run()

    modelname = './results/gen'
    print('saving generator model to ' + modelname)
    serializers.save_npz(modelname, gen)

    modelname = './results/dis'
    print('saving discriminator model to ' + modelname)
    serializers.save_npz(modelname, dis)

    optname = './results/opt_g'
    print('saving generator optimizer to ' + optname)
    serializers.save_npz(optname, opt_g)

    optname = './results/opt_d'
    # BUG FIX: message said "generator optimizer" while saving opt_d
    print('saving discriminator optimizer to ' + optname)
    serializers.save_npz(optname, opt_d)
コード例 #24
0
ファイル: trainer.py プロジェクト: Jihunlee326/Pytorch-GANs
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
                                         batch_size=batch_size,
                                         shuffle=True)

# GAN Loss function (vanilla GAN: binary cross-entropy on D's output)
adversarial_loss = nn.BCELoss()

# Initialize generator and discriminator
# NOTE(review): latent_dim, img_shape, cuda, lr, b1, b2 are defined
# earlier in the file, outside this excerpt.
generator = network.Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator = network.Discriminator(img_shape=img_shape)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=(b1, b2))

# gpu or cpu: tensor factory matching the selected device
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
コード例 #25
0
# Re-run raw-data preprocessing only when explicitly requested
IfInitial = False
dp = data_processor.DataProcessor()
if IfInitial:
    dp.init_data()
    print('initial done!')

# get training data and test data (preprocessed numpy dumps)
train_set_np = np.load("data/train_set.npy")
train_set_label_np = np.load("data/train_set_label.npy")

test_set_np = np.load("data/test_set.npy")
test_set_label_np = np.load("data/test_set_label.npy")

# network (CUDA required here — no CPU fallback)
G = nw.Generator().cuda()
D = nw.Discriminator().cuda()

# initialize weights from N(mean=0, std=0.01)
G.weight_init(mean=0, std=0.01)
D.weight_init(mean=0, std=0.01)

# load train data
BatchSize = 32

train_set = torch.load('data/train_data_set.lib')
train_data = torch_data.DataLoader(
    train_set,
    batch_size=BatchSize,
    shuffle=True,
    num_workers=2,
)
コード例 #26
0
ファイル: train_gan.py プロジェクト: skeeet/style-gan-pytorch
def train(settings, output_root):
    """Progressive-growing StyleGAN training loop.

    Trains a generator/discriminator pair level by level (a level `l`
    corresponds to 2**(l+1)-pixel images), alternating a fading phase
    (alpha ramps 0 -> 1) with a stabilizing phase, and maintains a
    long-term exponential moving average of the generator (`gs`) that
    is used for evaluation images and checkpointed as "gs.pth".

    Args:
        settings: configuration dict (network sizes, learning rates,
            loss type, batch sizes per level, save intervals, ...).
        output_root: pathlib.Path under which weights and tensorboard
            logs are written.
    """
    # directories
    weights_root = output_root.joinpath("weights")
    weights_root.mkdir()

    # settings
    amp_handle = amp.init(settings["use_apex"])

    if settings["use_cuda"]:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    dtype = torch.float32
    # Evaluation images are rendered in fp16 on GPU to save memory.
    test_device = torch.device("cuda:0")
    test_dtype = torch.float16

    loss_type = settings["loss"]

    z_dim = settings["network"]["z_dim"]

    # model
    label_size = len(settings["labels"])
    generator = network.Generator(settings["network"],
                                  label_size).to(device, dtype)
    discriminator = network.Discriminator(settings["network"],
                                          label_size).to(device, dtype)

    # long-term average
    gs = network.Generator(settings["network"], label_size).to(device, dtype)
    gs.load_state_dict(generator.state_dict())
    gs_beta = settings["gs_beta"]

    # Separate learning rate for the latent mapping network.
    lt_learning_rate = settings["learning_rates"]["latent_transformation"]
    g_learning_rate = settings["learning_rates"]["generator"]
    d_learning_rate = settings["learning_rates"]["discriminator"]
    g_opt = optim.Adam([{
        "params": generator.latent_transform.parameters(),
        "lr": lt_learning_rate
    }, {
        "params": generator.synthesis_module.parameters()
    }],
                       lr=g_learning_rate,
                       betas=(0.0, 0.99),
                       eps=1e-8)
    d_opt = optim.Adam(discriminator.parameters(),
                       lr=d_learning_rate,
                       betas=(0.0, 0.99),
                       eps=1e-8)

    # train data
    loader = data_loader.LabeledDataLoader(settings)

    if settings["use_yuv"]:
        converter = image_converter.YUVConverter()
    else:
        converter = image_converter.RGBConverter()

    # parameters
    level = settings["start_level"]
    generator.set_level(level)
    discriminator.set_level(level)
    gs.set_level(level)
    fading = False
    alpha = 1
    step = 0

    # log
    writer = SummaryWriter(str(output_root))
    test_rows = 12
    test_cols = 6
    # Fixed latents/labels so evaluation grids are comparable across steps.
    test_zs = utils.create_test_z(test_rows, test_cols, z_dim)
    test_z0 = torch.from_numpy(test_zs[0]).to(test_device, test_dtype)
    test_z1 = torch.from_numpy(test_zs[1]).to(test_device, test_dtype)
    test_labels0 = torch.randint(0, loader.label_size, (1, test_cols))
    test_labels0 = test_labels0.repeat(test_rows, 1).to(device)
    test_labels1 = torch.randint(0,
                                 loader.label_size, (test_rows, test_cols),
                                 device=test_device).view(-1)

    for loop in range(9999999):
        size = 2**(level + 1)

        batch_size = settings["batch_sizes"][level - 1]
        # Per-batch alpha increment so fading spans num_images_in_stage images.
        alpha_delta = batch_size / settings["num_images_in_stage"]

        image_count = 0

        for batch, labels in loader.generate(batch_size, size, size):
            # pre train
            step += 1
            image_count += batch_size
            if fading:
                alpha = min(1.0, alpha + alpha_delta)

            # data: HWC -> CHW, then normalise to the generator's range.
            batch = batch.transpose([0, 3, 1, 2])
            batch = converter.to_train_data(batch)
            trues = torch.from_numpy(batch).to(device, dtype)
            labels = torch.from_numpy(labels).to(device)

            # reset
            g_opt.zero_grad()
            d_opt.zero_grad()

            # === train discriminator ===
            z = utils.create_z(batch_size, z_dim)
            z = torch.from_numpy(z).to(device, dtype)
            fakes = generator.forward(z, labels, alpha)
            fakes_nograd = fakes.detach()

            for param in discriminator.parameters():
                param.requires_grad_(True)
            if loss_type == "wgan":
                d_loss, wd = loss.d_wgan_loss(discriminator, trues,
                                              fakes_nograd, labels, alpha)
            elif loss_type == "lsgan":
                d_loss = loss.d_lsgan_loss(discriminator, trues, fakes_nograd,
                                           labels, alpha)
            elif loss_type == "logistic":
                d_loss = loss.d_logistic_loss(discriminator, trues,
                                              fakes_nograd, labels, alpha)
            else:
                raise Exception(f"Invalid loss: {loss_type}")

            with amp_handle.scale_loss(d_loss, d_opt) as scaled_loss:
                scaled_loss.backward()
            d_opt.step()

            # === train generator ===
            z = utils.create_z(batch_size, z_dim)
            z = torch.from_numpy(z).to(device, dtype)
            fakes = generator.forward(z, labels, alpha)

            # Freeze D so generator gradients do not update it.
            for param in discriminator.parameters():
                param.requires_grad_(False)
            if loss_type == "wgan":
                g_loss = loss.g_wgan_loss(discriminator, fakes, labels, alpha)
            elif loss_type == "lsgan":
                g_loss = loss.g_lsgan_loss(discriminator, fakes, labels, alpha)
            elif loss_type == "logistic":
                g_loss = loss.g_logistic_loss(discriminator, fakes, labels,
                                              alpha)
            else:
                raise Exception(f"Invalid loss: {loss_type}")

            with amp_handle.scale_loss(g_loss, g_opt) as scaled_loss:
                scaled_loss.backward()
                del scaled_loss
            g_opt.step()

            del trues, fakes, fakes_nograd

            # update gs: exponential moving average of generator weights.
            for gparam, gsparam in zip(generator.parameters(),
                                       gs.parameters()):
                gsparam.data = (1 -
                                gs_beta) * gsparam.data + gs_beta * gparam.data
            gs.w_average.data = (
                1 - gs_beta
            ) * gs.w_average.data + gs_beta * generator.w_average.data

            # log
            if step % 1 == 0:
                print(f"lv{level}-{step}: "
                      f"a: {alpha:.5f} "
                      f"g: {g_loss.item():.7f} "
                      f"d: {d_loss.item():.7f} ")

                writer.add_scalar(f"lv{level}/loss_gen",
                                  g_loss.item(),
                                  global_step=step)
                writer.add_scalar(f"lv{level}/loss_disc",
                                  d_loss.item(),
                                  global_step=step)
                if loss_type == "wgan":
                    writer.add_scalar(f"lv{level}/wd", wd, global_step=step)

            del d_loss, g_loss

            # histogram
            if settings["save_steps"]["histogram"] > 0 and step % settings[
                    "save_steps"]["histogram"] == 0:
                gs.write_histogram(writer, step)
                for name, param in discriminator.named_parameters():
                    writer.add_histogram(f"disc/{name}",
                                         param.cpu().data.numpy(), step)

            # image
            if step % settings["save_steps"]["image"] == 0 or alpha == 0:
                fading_text = "fading" if fading else "stabilizing"
                with torch.no_grad():
                    # Evaluate with a throwaway fp16 copy of the EMA generator.
                    eval_gen = network.Generator(settings["network"],
                                                 label_size).to(
                                                     test_device,
                                                     test_dtype).eval()
                    eval_gen.load_state_dict(gs.state_dict())
                    eval_gen.synthesis_module.set_noise_fixed(True)
                    fakes = eval_gen.forward(test_z0, test_labels0, alpha)
                    fakes = torchvision.utils.make_grid(fakes,
                                                        nrow=test_cols,
                                                        padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = converter.from_generator_output(fakes)
                    writer.add_image(f"lv{level}_{fading_text}/intpl",
                                     torch.from_numpy(fakes), step)
                    fakes = eval_gen.forward(test_z1, test_labels1, alpha)
                    fakes = torchvision.utils.make_grid(fakes,
                                                        nrow=test_cols,
                                                        padding=0)
                    fakes = fakes.to(torch.float32).cpu().numpy()
                    fakes = converter.from_generator_output(fakes)
                    writer.add_image(f"lv{level}_{fading_text}/random",
                                     torch.from_numpy(fakes), step)
                    del eval_gen
                # memory usage
                writer.add_scalar("memory_allocated(MB)",
                                  torch.cuda.memory_allocated() /
                                  (1024 * 1024),
                                  global_step=step)

            # model save
            if step % settings["save_steps"][
                    "model"] == 0 and level >= 5 and not fading:
                savedir = weights_root.joinpath(f"{step}_lv{level}")
                savedir.mkdir()
                torch.save(generator.state_dict(), savedir.joinpath("gen.pth"))
                # BUG FIX: previously this saved generator.state_dict() a
                # second time; "gs.pth" must hold the long-term average model.
                torch.save(gs.state_dict(), savedir.joinpath("gs.pth"))
                torch.save(discriminator.state_dict(),
                           savedir.joinpath("disc.pth"))

            # switch fading/stabilizing
            if image_count > settings["num_images_in_stage"]:
                if fading:
                    print("start stabilizing")
                    fading = False
                    alpha = 1
                    image_count = 0
                elif level < settings["max_level"]:
                    print(f"end lv: {level}")
                    break

        # level up
        if level < settings["max_level"]:
            level = level + 1
            generator.set_level(level)
            discriminator.set_level(level)
            gs.set_level(level)
            fading = True
            alpha = 0
            print(f"lv up: {level}")

            if settings["reset_optimizer"]:
                g_opt = optim.Adam(
                    [{
                        "params": generator.latent_transform.parameters(),
                        "lr": lt_learning_rate
                    }, {
                        "params": generator.synthesis_module.parameters()
                    }],
                    lr=g_learning_rate,
                    betas=(0.0, 0.99),
                    eps=1e-8)
                d_opt = optim.Adam(discriminator.parameters(),
                                   lr=d_learning_rate,
                                   betas=(0.0, 0.99),
                                   eps=1e-8)
コード例 #27
0
    def __init__(self, config):
        """Set up progressive-GAN training state, optionally resuming
        from the newest checkpoint under ``repo/model/``.

        Args:
            config: parsed options (nz, optimizer, lr, max_resl,
                trns_tick/stab_tick/TICK schedule, resume flag,
                random_seed, use_tb, ...).
        """
        self.config = config
        # Make every new tensor default to the chosen device's float type.
        if torch.cuda.is_available():
            self.use_cuda = True
            torch.set_default_tensor_type("torch.cuda.FloatTensor")
        else:
            self.use_cuda = False
            torch.set_default_tensor_type("torch.FloatTensor")

        self.nz = config.nz
        self.optimizer = config.optimizer

        self.resl = 2  # we start from 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.accelerate = 1
        self.wgan_target = 1.0
        # Tick counts controlling transition (fade-in) and stabilization phases.
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        self.skip = False
        self.globalIter = 0
        self.globalTick = 0
        self.wgan_epsilon = 0.001
        self.stack = 0
        self.wgan_lambda = 10.0
        self.just_passed = False
        if self.config.resume:
            # Checkpoint filenames are assumed to end in "_T<tick>.<ext>";
            # pick the generator/discriminator pair with the highest tick.
            saved_models = os.listdir("repo/model/")
            iterations = list(
                map(lambda x: int(x.split("_")[-1].split(".")[0][1:]),
                    saved_models))
            self.last_iteration = max(iterations)
            selected_indexes = np.where(
                [x == self.last_iteration for x in iterations])[0]
            G_last_model = [
                saved_models[x] for x in selected_indexes
                if "gen" in saved_models[x]
            ][0]
            D_last_model = [
                saved_models[x] for x in selected_indexes
                if "dis" in saved_models[x]
            ][0]
            # Recover the global iteration count from the newest saved
            # image grid (grids are written every save_img_every iters).
            saved_grids = os.listdir("repo/save/grid")
            global_iterations = list(
                map(lambda x: int(x.split("_")[0]), saved_grids))
            self.globalIter = self.config.save_img_every * max(
                global_iterations)
            print("Resuming after " + str(self.last_iteration) +
                  " ticks and " + str(self.globalIter) + " iterations")
            G_weights = torch.load("repo/model/" + G_last_model)
            D_weights = torch.load("repo/model/" + D_last_model)
            self.resuming = True
        else:
            self.resuming = False

        # Progress-tracking state for the grow/fade schedule.
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {"gen": None, "dis": None}
        self.complete = {"gen": 0, "dis": 0}
        self.phase = "init"
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift

        # network and criterion
        self.G = net.Generator(config)
        self.D = net.Discriminator(config)
        print("Generator structure: ")
        print(self.G.model)
        print("Discriminator structure: ")
        print(self.D.model)
        self.mse = torch.nn.MSELoss()
        if self.use_cuda:
            self.mse = self.mse.cuda()
            torch.cuda.manual_seed(config.random_seed)
            # Wrap in DataParallel (single device) so state dicts use the
            # "module." prefix expected by the resume code below.
            self.G = torch.nn.DataParallel(self.G,
                                           device_ids=[0]).cuda(device=0)
            self.D = torch.nn.DataParallel(self.D,
                                           device_ids=[0]).cuda(device=0)

        # define tensors, ship model to cuda, and get dataloader.
        self.renew_everything()
        if self.resuming:
            # Restore full trainer state (resolution, counters, phase flags)
            # from the checkpoint, then weights and optimizer state.
            self.resl = G_weights["resl"]
            self.globalIter = G_weights["globalIter"]
            self.globalTick = G_weights["globalTick"]
            self.kimgs = G_weights["kimgs"]
            self.epoch = G_weights["epoch"]
            self.phase = G_weights["phase"]
            self.fadein = G_weights["fadein"]
            self.complete = G_weights["complete"]
            self.flag_flush_gen = G_weights["flag_flush_gen"]
            self.flag_flush_dis = G_weights["flag_flush_dis"]
            self.stack = G_weights["stack"]

            print("Resuming at " + str(self.resl) + " definition after " +
                  str(self.epoch) + " epochs")
            self.G.module.load_state_dict(G_weights["state_dict"])
            self.D.module.load_state_dict(D_weights["state_dict"])
            self.opt_g.load_state_dict(G_weights["optimizer"])
            self.opt_d.load_state_dict(D_weights["optimizer"])

        # tensorboard
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()
コード例 #28
0
def meta_train(gpu, dataset_path, continue_id):
    """Meta-train a talking-head model (Embedder/Generator/Discriminator)
    on a VoxCeleb-style dataset.

    Each video yields K support frames (to compute an average embedding
    e_hat) plus one held-out frame t; the generator reconstructs frame t
    from its landmarks and e_hat. D is optimized twice per batch: once
    jointly with E/G and once on the detached generation.

    Args:
        gpu: bool, run on CUDA with DataParallel wrappers when True.
        dataset_path: root directory of the preprocessed .vid dataset.
        continue_id: run id to resume from, or None to start fresh.
    """
    run_start = datetime.now()
    logging.info('===== META-TRAINING =====')
    # GPU / CPU --------------------------------------------------------------------------------------------------------
    if gpu:
        dtype = torch.cuda.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info(f'Running on GPU: {torch.cuda.current_device()}.')
    else:
        dtype = torch.FloatTensor
        torch.set_default_tensor_type(dtype)
        logging.info(f'Running on CPU.')

    # DATASET-----------------------------------------------------------------------------------------------------------
    logging.info(f'Training using dataset located in {dataset_path}')
    raw_dataset = VoxCelebDataset(
        root=dataset_path,
        extension='.vid',
        shuffle_frames=True,
        # subset_size=1,
        transform=transforms.Compose([
            transforms.Resize(config.IMAGE_SIZE),
            transforms.CenterCrop(config.IMAGE_SIZE),
            transforms.ToTensor(),
        ]))
    dataset = DataLoader(raw_dataset,
                         batch_size=config.BATCH_SIZE,
                         shuffle=True)

    # NETWORK ----------------------------------------------------------------------------------------------------------

    # Discriminator holds one embedding column per training video.
    E = network.Embedder().type(dtype)
    G = network.Generator().type(dtype)
    D = network.Discriminator(len(raw_dataset)).type(dtype)

    # E and G share one optimizer; D gets its own.
    optimizer_E_G = Adam(params=list(E.parameters()) + list(G.parameters()),
                         lr=config.LEARNING_RATE_E_G)
    optimizer_D = Adam(params=D.parameters(), lr=config.LEARNING_RATE_D)

    criterion_E_G = network.LossEG(feed_forward=True)
    criterion_D = network.LossD()

    if gpu:
        E = DataParallel(E)
        G = DataParallel(G)
        D = ParallelDiscriminator(D)
        criterion_E_G = DataParallel(criterion_E_G)
        criterion_D = DataParallel(criterion_D)

    if continue_id is not None:
        E = load_model(E, 'Embedder', continue_id)
        G = load_model(G, 'Generator', continue_id)
        D = load_model(D, 'Discriminator', continue_id)

    # TRAINING LOOP ----------------------------------------------------------------------------------------------------
    logging.info(f'Starting training loop. '
                 f'Epochs: {config.EPOCHS} '
                 f'Batches: {len(dataset)} '
                 f'Batch Size: {config.BATCH_SIZE}')

    for epoch in range(config.EPOCHS):
        epoch_start = datetime.now()
        batch_durations = []

        E.train()
        G.train()
        D.train()

        for batch_num, (i, video) in enumerate(dataset):
            batch_start = datetime.now()
            video = video.type(dtype)  # [B, K+1, 2, C, W, H]

            # Put one frame aside (frame t)
            t = video[:, -1, ...]  # [B, 2, C, W, H]
            video = video[:, :-1, ...]  # [B, K, 2, C, W, H]
            dims = video.shape

            # Calculate average encoding vector for video
            e_in = video.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4],
                                 dims[5])  # [BxK, 2, C, W, H]
            # x: frames, y: landmark images (axis 1 of the pair dimension).
            x, y = e_in[:, 0, ...], e_in[:, 1, ...]
            e_vectors = E(x, y).reshape(dims[0], dims[1], -1)  # B, K, len(e)
            e_hat = e_vectors.mean(dim=1)

            # Generate frame using landmarks from frame t
            x_t, y_t = t[:, 0, ...], t[:, 1, ...]
            x_hat = G(y_t, e_hat)

            # Optimize E_G and D
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_E_G.zero_grad()
            optimizer_D.zero_grad()

            # Joint step: E/G loss uses D's realism score, the video's
            # embedding column W[:, i] and D's intermediate activations.
            loss_E_G = criterion_E_G(x_t, x_hat, r_x_hat, e_hat,
                                     D.W[:, i].transpose(1, 0), D_act,
                                     D_act_hat).mean()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss = loss_E_G + loss_D
            loss.backward()

            optimizer_E_G.step()
            optimizer_D.step()

            # Optimize D again
            x_hat = G(y_t, e_hat).detach()  # detach: no gradient into G
            r_x_hat, D_act_hat = D(x_hat, y_t, i)
            r_x, D_act = D(x_t, y_t, i)

            optimizer_D.zero_grad()
            loss_D = criterion_D(r_x, r_x_hat).mean()
            loss_D.backward()
            optimizer_D.step()

            batch_end = datetime.now()
            batch_duration = batch_end - batch_start
            batch_durations.append(batch_duration)
            # SHOW PROGRESS --------------------------------------------------------------------------------------------
            if (batch_num + 1) % 1 == 0 or batch_num == 0:
                logging.info(
                    f'Epoch {epoch + 1}: [{batch_num + 1}/{len(dataset)}] | '
                    f'Time: {batch_duration} | '
                    f'Loss_E_G = {loss_E_G.item():.4} Loss_D {loss_D.item():.4}'
                )
                logging.debug(
                    f'D(x) = {r_x.mean().item():.4} D(x_hat) = {r_x_hat.mean().item():.4}'
                )

            # SAVE IMAGES ----------------------------------------------------------------------------------------------
            # Latest ground-truth/generated pair, overwritten every batch.
            save_image(
                os.path.join(config.GENERATED_DIR, f'last_result_x.png'),
                x_t[0])
            save_image(
                os.path.join(config.GENERATED_DIR, f'last_result_x_hat.png'),
                x_hat[0])

            # Periodic timestamped snapshots.
            if (batch_num + 1) % 1000 == 0:
                save_image(
                    os.path.join(config.GENERATED_DIR,
                                 f'{datetime.now():%Y%m%d_%H%M%S%f}_x.png'),
                    x_t[0])
                save_image(
                    os.path.join(
                        config.GENERATED_DIR,
                        f'{datetime.now():%Y%m%d_%H%M%S%f}_x_hat.png'),
                    x_hat[0])

            # SAVE MODELS ----------------------------------------------------------------------------------------------
            if (batch_num + 1) % 100 == 0:
                save_model(E, 'Embedder', gpu, run_start)
                save_model(G, 'Generator', gpu, run_start)
                save_model(D, 'Discriminator', gpu, run_start)

        # SAVE MODELS --------------------------------------------------------------------------------------------------

        save_model(E, 'Embedder', gpu, run_start)
        save_model(G, 'Generator', gpu, run_start)
        save_model(D, 'Discriminator', gpu, run_start)
        epoch_end = datetime.now()
        logging.info(
            f'Epoch {epoch + 1} finished in {epoch_end - epoch_start}. '
            f'Average batch time: {sum(batch_durations, timedelta(0)) / len(batch_durations)}'
        )
コード例 #29
0
    def __init__(self, config):
        """Set up multi-GPU progressive-resolution GAN training:
        output folders, networks, per-resolution batch/epoch schedules,
        and one dataloader per resolution level.

        Args:
            config: parsed options (model_name, n_gpu, start_resl,
                max_resl, nz, optimizer, lr, G_pth/D_pth checkpoint
                paths, train_data_root, G_upsam_mode, ...).
        """

        # NVML handle for GPU 0, used for utilisation/memory queries.
        nvidia_smi.nvmlInit()
        self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

        self.config = config
        # NOTE(review): use_cuda is only assigned when CUDA is available;
        # on a CPU-only machine later reads would raise AttributeError
        # (the unconditional .cuda() calls below would fail first anyway).
        if torch.cuda.is_available():
            self.use_cuda = True
        ngpu = config.n_gpu

        # prepare folders (one subfolder per run, named after the model)
        if not os.path.exists('./checkpoint_dir/' + config.model_name):
            os.mkdir('./checkpoint_dir/' + config.model_name)
        if not os.path.exists('./images/' + config.model_name):
            os.mkdir('./images/' + config.model_name)
        if not os.path.exists('./tb_log/' + config.model_name):
            os.mkdir('./tb_log/' + config.model_name)
        if not os.path.exists('./code_backup/' + config.model_name):
            os.mkdir('./code_backup/' + config.model_name)
        # snapshot the source files used for this run
        os.system('cp *.py ' + './code_backup/' + config.model_name)

        # network
        self.G = net.Generator(config).cuda()
        self.D = net.Discriminator(config).cuda()
        print('Generator structure: ')
        print(self.G.model)
        print('Discriminator structure: ')
        print(self.D.model)

        # spread both networks over all configured GPUs
        devices = [i for i in range(ngpu)]
        self.G = MyDataParallel(self.G, device_ids=devices)
        self.D = MyDataParallel(self.D, device_ids=devices)

        self.start_resl = config.start_resl
        self.max_resl = config.max_resl

        self.load_model(G_pth=config.G_pth, D_pth=config.D_pth)

        self.nz = config.nz
        self.optimizer = config.optimizer
        self.lr = config.lr

        self.fadein = {'gen': None, 'dis': None}
        self.upsam_mode = self.config.G_upsam_mode  # either 'nearest' or 'tri-linear'

        # Per-resolution-level schedules (key = log2 resolution):
        # batch size shrinks as images grow to fit GPU memory.
        self.batchSize = {
            2: 64 * ngpu,
            3: 64 * ngpu,
            4: 64 * ngpu,
            5: 64 * ngpu,
            6: 48 * ngpu,
            7: 12 * ngpu
        }
        # epochs spent fading in a new level / stabilizing it, and the
        # number of discriminator steps per generator step.
        self.fadeInEpochs = {2: 0, 3: 1, 4: 1, 5: 2000, 6: 2000, 7: 2000}
        self.stableEpochs = {2: 0, 3: 0, 4: 3510, 5: 10100, 6: 10600, 7: 50000}
        self.ncritic = {2: 5, 3: 5, 4: 5, 5: 3, 6: 3, 7: 3}

        # size 16 need 5000-7000 enough
        # size 32 need 16000-30000 enough

        self.global_batch_done = 0

        # define all dataloaders into a dictionary
        self.dataloaders = {}
        for resl in range(self.start_resl, self.max_resl + 1):
            self.dataloaders[resl] = DataLoader(
                DL.Data(config.train_data_root + 'resl{}/'.format(2**resl)),
                batch_size=self.batchSize[resl],
                shuffle=True,
                drop_last=True)

        # ship new model to cuda, and update optimizer
        self.renew_everything()
コード例 #30
0
ファイル: main.py プロジェクト: avani17101/MLMME
def main(input_path, training_mode='mle-gan'):
    """Train a sequence model on the event log at *input_path*, reload the
    best validation checkpoint, evaluate it on the test split, and finally
    generate suffixes with beam search at several beam widths.

    Args:
        input_path: path to the raw input data.
        training_mode: 'mle' for plain maximum-likelihood training, or
            'mle-gan' (default) for adversarial fine-tuning.

    Returns:
        (data_obj, model): the preprocessing object and the trained network.
    """
    # Preprocess the raw log into the tensors the networks consume.
    data_obj = pr.Preprocessing()
    data_obj.training_mode = training_mode
    data_obj.run(input_path)

    # Network hyper-parameters.
    input_size = len(data_obj.selected_columns)
    batch = data_obj.batch_size
    hidden_size = 200
    num_layers = 5
    num_directions = 1  # would be 2 for a bidirectional encoder
    beam_width = [1, 3, 5, 7, 10, 15]  # beam-search window sizes to sweep

    # Seq2seq generator (encoder + decoder) and an RNN discriminator.
    enc = nw.Encoder(input_size, batch, hidden_size, num_layers,
                     num_directions).cuda()
    dec = nw.Decoder(input_size, batch, hidden_size, num_layers,
                     dropout=.3).cuda()
    dec.duration_time_loc = data_obj.duration_time_loc
    rnnD = nw.Discriminator(input_size,
                            batch,
                            hidden_size,
                            num_layers,
                            dropout=.3).cuda()
    model = nw.Seq2Seq(enc, dec).cuda()

    # Weight initialisation and one RMSprop optimizer per network.
    model.apply(nw.init_weights)
    rnnD.apply(nw.init_weights)
    optimizerG = torch.optim.RMSprop(model.parameters(), lr=5e-5)
    optimizerD = torch.optim.RMSprop(rnnD.parameters(), lr=5e-5)

    if training_mode == 'mle':
        print("Training via MLE")
        nw.train_mle(model, optimizerG, data_obj)
        # Reload the checkpoint with the best validation entropy.
        path = os.path.join(data_obj.output_dir, 'rnnG(validation entropy).m')
        model.load_state_dict(torch.load(path))
        nw.model_eval_test(model, data_obj, mode='test')
    elif training_mode == 'mle-gan':
        print("Training via MLE-GAN")
        nw.train_gan(model, rnnD, optimizerG, optimizerD, data_obj)
        # Reload the checkpoint with the best validation entropy.
        path = os.path.join(data_obj.output_dir,
                            'rnnG(validation entropy gan).m')
        model.load_state_dict(torch.load(path))
        nw.model_eval_test(model, data_obj, mode='test')

    # Suffix generation: run beam search at every configured width and
    # score the generated suffixes against ground truth.
    print("start generating suffixes using beam search!")
    for width in beam_width:
        sf.suffix_generate(model, data_obj, candidate_num=width)
        sf.suffix_similarity(data_obj, beam_size=width)

    return data_obj, model