예제 #1
0
def main(_):
    skip_layers = ['fc8']  # no pre-trained weights
    train_layers = ['fc8']
    finetune_layers = [
        'fc7', 'fc6', 'conv5', 'conv4', 'conv3', 'conv2', 'conv1'
    ]

    writer_dir = "logs/{}".format(FLAGS.dataset)
    checkpoint_dir = "checkpoints/{}".format(FLAGS.dataset)

    if tf.gfile.Exists(writer_dir):
        tf.gfile.DeleteRecursively(writer_dir)
    tf.gfile.MakeDirs(writer_dir)

    if tf.gfile.Exists(checkpoint_dir):
        tf.gfile.DeleteRecursively(checkpoint_dir)
    tf.gfile.MakeDirs(checkpoint_dir)

    generator = DataGenerator(data_dir=FLAGS.data_dir,
                              dataset=FLAGS.dataset,
                              batch_size=FLAGS.batch_size,
                              num_threads=FLAGS.num_threads)

    model = VS_CNN(num_classes=2, skip_layers=skip_layers)
    loss = loss_fn(model)
    warm_up, train_op, learning_rate = train_fn(loss, generator,
                                                finetune_layers, train_layers)
    saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints)

    config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        logger = Logger(writer_dir, sess.graph)
        result_file = open('result_{}_base.txt'.format(FLAGS.dataset), 'w')

        sess.run(tf.global_variables_initializer())
        model.load_initial_weights(sess)

        print("{} Start training...".format(datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.now(), writer_dir))
        for epoch in range(FLAGS.num_epochs):
            print("\n{} Epoch: {}/{}".format(datetime.now(), epoch + 1,
                                             FLAGS.num_epochs))
            result_file.write("\n{} Epoch: {}/{}\n".format(
                datetime.now(), epoch + 1, FLAGS.num_epochs))

            if epoch < 20:
                update_op = warm_up
            else:
                update_op = train_op
            train(sess, model, generator, update_op, learning_rate, loss,
                  epoch, logger)

            test(sess, model, generator, result_file)
            # save_model(sess, saver, epoch, checkpoint_dir)

        result_file.close()
예제 #2
0
    def __init__(self, config):

        self.config = config

        # log directory
        if self.config.experiment_name=='':
            self.experiment_name = current_time
        else:
            self.experiment_name = self.config.experiment_name

        self.log_dir = config.log_dir+'/'+self.experiment_name
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # save config settings to file
        with open(self.log_dir+'/train_config.txt','w') as f:
            print('------------- training configuration -------------', file=f)
            for k, v in vars(config).items():
                print(('{}: {}').format(k, v), file=f)
            print(' ... loading training configuration ... ')
            print(' ... saving training configuration to {}'.format(f))

        self.train_data_root = self.config.data_root

        # training samples
        self.train_sample_dir = self.log_dir+'/samples_train'

        # checkpoints
        self.ckpt_dir = self.log_dir+'/ckpts'
        print('#######################')

        print("checkpoint dir:",self.ckpt_dir)
        print('#######################')

        # tensorboard
        if self.config.tb_logging:
            self.tb_dir = self.log_dir+'/tensorboard'

        self.use_cuda = use_cuda
        self.nz = config.nz
        self.nc = config.nc
        self.optimizer = config.optimizer
        self.batch_size_table = config.batch_size_table
        self.lr = config.lr
        self.d_eps_penalty = config.d_eps_penalty
        self.acgan = config.acgan
        self.max_resl = int(np.log2(config.max_resl))
        self.nframes_in = config.nframes_in
        self.nframes_pred = config.nframes_pred
        self.nframes = self.nframes_in+self.nframes_pred
        self.ext = config.ext
        self.nworkers = 4
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.complete = 0.0
        self.x_add_noise = config.x_add_noise
        self.fadein = {'G':None, 'D':None}
        self.init_resl = 2
        self.init_img_size = int(pow(2, self.init_resl))

        # initialize model G as FutureGenerator from model.py
        self.G = model.FutureGenerator(config)

        # initialize model D as Discriminator from model.py
        self.D = model.Discriminator(config)

        # define losses
        if self.config.loss=='lsgan':
            self.criterion = torch.nn.MSELoss()

        elif self.config.loss=='gan':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()

        elif self.config.loss=='wgan_gp':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Loss is undefined! Please set one of the following: `gan`, `lsgan` or `wgan_gp`')

        # check --use_ckpt
        # if --use_ckpt==False: build initial model
        # if --use_ckpt==True: load and build model from specified checkpoints
        if self.config.use_ckpt==False:

            print(' ... creating initial models ... ')

            # set initial model parameters
            self.resl = self.init_resl
            self.start_resl = self.init_resl
            self.globalIter = 0
            self.nsamples = 0
            self.stack = 0
            self.epoch = 0
            self.iter_start = 0
            self.phase = 'init'
            self.flag_flush = False

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()

            # count model parameters
            nparams_g = count_model_params(self.G)
            nparams_d = count_model_params(self.D)

            # save initial model structure to file
            with open(self.log_dir+'/initial_model_structure_{}x{}.txt'.format(self.init_img_size, self.init_img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), ', Batch size: ', self.batch_size, file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('Number of Generator`s model parameters: ', file=f)
                print(nparams_g, file=f)
                print('Number of Discriminator`s model parameters: ', file=f)
                print(nparams_d, file=f)
                print('--------------------------------------------------', file=f)
                print('Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... initial models have been built successfully ... ')
                print(' ... saving initial model strutures to {}'.format(f))

            # ship everything to cuda and parallelize for ngpu>1
            if self.use_cuda:
                self.criterion = self.criterion.cuda()
                torch.cuda.manual_seed(config.random_seed)
                if config.ngpu==1:
                    self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                    self.D = torch.nn.DataParallel(self.D).cuda(device=0)
                else:
                    gpus = []
                    for i  in range(config.ngpu):
                        gpus.append(i)
                    self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                    self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        else:

            # re-ship everything to cuda
            if self.use_cuda:
                self.G = self.G.cuda()
                self.D = self.D.cuda()

            # load checkpoint
            print(' ... loading models from checkpoints ... {} and {}'.format(self.config.ckpt_path[0], self.config.ckpt_path[1]))
            self.ckpt_g = torch.load(self.config.ckpt_path[0])
            self.ckpt_d = torch.load(self.config.ckpt_path[1])

            # get model parameters
            self.resl = self.ckpt_g['resl']
            self.start_resl = int(self.ckpt_g['resl'])
            self.iter_start = self.ckpt_g['iter']+1
            self.globalIter = int(self.ckpt_g['globalIter'])
            self.stack = int(self.ckpt_g['stack'])
            self.nsamples = int(self.ckpt_g['nsamples'])
            self.epoch = int(self.ckpt_g['epoch'])
            self.fadein['G'] = self.ckpt_g['fadein']
            self.fadein['D'] = self.ckpt_d['fadein']
            self.phase = self.ckpt_d['phase']
            self.complete  = self.ckpt_d['complete']
            self.flag_flush = self.ckpt_d['flag_flush']
            img_size = int(pow(2, floor(self.resl)))

            # get model structure
            self.G = self.ckpt_g['G_structure']
            self.D = self.ckpt_d['D_structure']

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()
            self.schedule_resl()
            self.nsamples = int(self.ckpt_g['nsamples'])

            # save loaded model structure to file
            with open(self.log_dir+'/resumed_model_structure_{}x{}.txt'.format(img_size, img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... models have been loaded successfully from checkpoints ... ')
                print(' ... saving resumed model strutures to {}'.format(f))

            # load model state_dict
            self.G.load_state_dict(self.ckpt_g['state_dict'])
            self.D.load_state_dict(self.ckpt_d['state_dict'])

            # load optimizer state dict
            lr = self.lr
            for i in range(1,int(floor(self.resl))-1):
                self.lr = lr*(self.config.lr_decay**i)
#            self.opt_g.load_state_dict(self.ckpt_g['optimizer'])
#            self.opt_d.load_state_dict(self.ckpt_d['optimizer'])
#            for param_group in self.opt_g.param_groups:
#                self.lr = param_group['lr']

        # tensorboard logging
        self.tb_logging = self.config.tb_logging
        if self.tb_logging==True:
            if not os.path.exists(self.tb_dir):
                os.makedirs(self.tb_dir)
            self.logger = Logger(self.tb_dir)
예제 #3
0
class Trainer:
    '''
    Class to train a FutureGAN model.

    Data is assumed to be arranged in this way:
        data_root/video/frame.ext -> dataset/train/video1/frame1.ext
                                                  -> dataset/train/video1/frame2.ext
                                                  -> dataset/train/video2/frame1.ext
                                                  -> ...
    '''

    def __init__(self, config):

        self.config = config

        # log directory
        if self.config.experiment_name=='':
            self.experiment_name = current_time
        else:
            self.experiment_name = self.config.experiment_name

        self.log_dir = config.log_dir+'/'+self.experiment_name
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # save config settings to file
        with open(self.log_dir+'/train_config.txt','w') as f:
            print('------------- training configuration -------------', file=f)
            for k, v in vars(config).items():
                print(('{}: {}').format(k, v), file=f)
            print(' ... loading training configuration ... ')
            print(' ... saving training configuration to {}'.format(f))

        self.train_data_root = self.config.data_root

        # training samples
        self.train_sample_dir = self.log_dir+'/samples_train'

        # checkpoints
        self.ckpt_dir = self.log_dir+'/ckpts'
        print('#######################')

        print("checkpoint dir:",self.ckpt_dir)
        print('#######################')

        # tensorboard
        if self.config.tb_logging:
            self.tb_dir = self.log_dir+'/tensorboard'

        self.use_cuda = use_cuda
        self.nz = config.nz
        self.nc = config.nc
        self.optimizer = config.optimizer
        self.batch_size_table = config.batch_size_table
        self.lr = config.lr
        self.d_eps_penalty = config.d_eps_penalty
        self.acgan = config.acgan
        self.max_resl = int(np.log2(config.max_resl))
        self.nframes_in = config.nframes_in
        self.nframes_pred = config.nframes_pred
        self.nframes = self.nframes_in+self.nframes_pred
        self.ext = config.ext
        self.nworkers = 4
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.complete = 0.0
        self.x_add_noise = config.x_add_noise
        self.fadein = {'G':None, 'D':None}
        self.init_resl = 2
        self.init_img_size = int(pow(2, self.init_resl))

        # initialize model G as FutureGenerator from model.py
        self.G = model.FutureGenerator(config)

        # initialize model D as Discriminator from model.py
        self.D = model.Discriminator(config)

        # define losses
        if self.config.loss=='lsgan':
            self.criterion = torch.nn.MSELoss()

        elif self.config.loss=='gan':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()

        elif self.config.loss=='wgan_gp':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Loss is undefined! Please set one of the following: `gan`, `lsgan` or `wgan_gp`')

        # check --use_ckpt
        # if --use_ckpt==False: build initial model
        # if --use_ckpt==True: load and build model from specified checkpoints
        if self.config.use_ckpt==False:

            print(' ... creating initial models ... ')

            # set initial model parameters
            self.resl = self.init_resl
            self.start_resl = self.init_resl
            self.globalIter = 0
            self.nsamples = 0
            self.stack = 0
            self.epoch = 0
            self.iter_start = 0
            self.phase = 'init'
            self.flag_flush = False

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()

            # count model parameters
            nparams_g = count_model_params(self.G)
            nparams_d = count_model_params(self.D)

            # save initial model structure to file
            with open(self.log_dir+'/initial_model_structure_{}x{}.txt'.format(self.init_img_size, self.init_img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), ', Batch size: ', self.batch_size, file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('Number of Generator`s model parameters: ', file=f)
                print(nparams_g, file=f)
                print('Number of Discriminator`s model parameters: ', file=f)
                print(nparams_d, file=f)
                print('--------------------------------------------------', file=f)
                print('Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... initial models have been built successfully ... ')
                print(' ... saving initial model strutures to {}'.format(f))

            # ship everything to cuda and parallelize for ngpu>1
            if self.use_cuda:
                self.criterion = self.criterion.cuda()
                torch.cuda.manual_seed(config.random_seed)
                if config.ngpu==1:
                    self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                    self.D = torch.nn.DataParallel(self.D).cuda(device=0)
                else:
                    gpus = []
                    for i  in range(config.ngpu):
                        gpus.append(i)
                    self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                    self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        else:

            # re-ship everything to cuda
            if self.use_cuda:
                self.G = self.G.cuda()
                self.D = self.D.cuda()

            # load checkpoint
            print(' ... loading models from checkpoints ... {} and {}'.format(self.config.ckpt_path[0], self.config.ckpt_path[1]))
            self.ckpt_g = torch.load(self.config.ckpt_path[0])
            self.ckpt_d = torch.load(self.config.ckpt_path[1])

            # get model parameters
            self.resl = self.ckpt_g['resl']
            self.start_resl = int(self.ckpt_g['resl'])
            self.iter_start = self.ckpt_g['iter']+1
            self.globalIter = int(self.ckpt_g['globalIter'])
            self.stack = int(self.ckpt_g['stack'])
            self.nsamples = int(self.ckpt_g['nsamples'])
            self.epoch = int(self.ckpt_g['epoch'])
            self.fadein['G'] = self.ckpt_g['fadein']
            self.fadein['D'] = self.ckpt_d['fadein']
            self.phase = self.ckpt_d['phase']
            self.complete  = self.ckpt_d['complete']
            self.flag_flush = self.ckpt_d['flag_flush']
            img_size = int(pow(2, floor(self.resl)))

            # get model structure
            self.G = self.ckpt_g['G_structure']
            self.D = self.ckpt_d['D_structure']

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()
            self.schedule_resl()
            self.nsamples = int(self.ckpt_g['nsamples'])

            # save loaded model structure to file
            with open(self.log_dir+'/resumed_model_structure_{}x{}.txt'.format(img_size, img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... models have been loaded successfully from checkpoints ... ')
                print(' ... saving resumed model strutures to {}'.format(f))

            # load model state_dict
            self.G.load_state_dict(self.ckpt_g['state_dict'])
            self.D.load_state_dict(self.ckpt_d['state_dict'])

            # load optimizer state dict
            lr = self.lr
            for i in range(1,int(floor(self.resl))-1):
                self.lr = lr*(self.config.lr_decay**i)
#            self.opt_g.load_state_dict(self.ckpt_g['optimizer'])
#            self.opt_d.load_state_dict(self.ckpt_d['optimizer'])
#            for param_group in self.opt_g.param_groups:
#                self.lr = param_group['lr']

        # tensorboard logging
        self.tb_logging = self.config.tb_logging
        if self.tb_logging==True:
            if not os.path.exists(self.tb_dir):
                os.makedirs(self.tb_dir)
            self.logger = Logger(self.tb_dir)


    def schedule_resl(self):

        # trns and stab if resl > 2
        if floor(self.resl)!=2:
            self.trns_tick = self.config.trns_tick
            self.stab_tick = self.config.stab_tick

        # alpha and delta parameters for smooth fade-in (resl-interpolation)
        delta = 1.0/(self.trns_tick+self.stab_tick)
        d_alpha = 1.0*self.batch_size/self.trns_tick/len(self.dataset)

        # update alpha if FadeInLayer exist
        if self.fadein['D'] is not None:
            if self.resl%1.0 < (self.trns_tick)*delta:
                self.fadein['G'][0].update_alpha(d_alpha)
                self.fadein['G'][1].update_alpha(d_alpha)
                self.fadein['D'].update_alpha(d_alpha)
                self.complete = self.fadein['D'].alpha*100
                self.phase = 'trns'
            elif self.resl%1.0 >= (self.trns_tick)*delta and self.phase != 'final':
                self.phase = 'stab'

        # increase resl linearly every tick
        prev_nsamples = self.nsamples
        self.nsamples = self.nsamples + self.batch_size
        if (self.nsamples%len(self.dataset)) < (prev_nsamples%len(self.dataset)):
            self.nsamples = 0

            prev_resl = floor(self.resl)
            self.resl = self.resl + delta
            self.resl = max(2, min(10.5, self.resl))        # clamping, range: 4 ~ 1024

            # flush network.
            if self.flag_flush and self.resl%1.0 >= (self.trns_tick)*delta and prev_resl!=2:
                if self.fadein['D'] is not None:
                    self.fadein['G'][0].update_alpha(d_alpha)
                    self.fadein['G'][1].update_alpha(d_alpha)
                    self.fadein['D'].update_alpha(d_alpha)
                    self.complete = self.fadein['D'].alpha*100
                self.flag_flush = False
                self.G.module.flush_network()   # flush G
                self.D.module.flush_network()   # flush and,
                self.fadein['G'] = None
                self.fadein['D'] = None
                self.complete = 0.0
                if floor(self.resl) < self.max_resl and self.phase != 'final':
                    self.phase = 'stab'
                self.print_model_structure()

            # grow network.
            if floor(self.resl) != prev_resl and floor(self.resl)<self.max_resl+1:
                self.lr = self.lr * float(self.config.lr_decay)
                self.G.module.grow_network(floor(self.resl))
                self.D.module.grow_network(floor(self.resl))
                self.renew_everything()
                self.fadein['G'] = [self.G.module.model.fadein_block_decode, self.G.module.model.fadein_block_encode]
                self.fadein['D'] = self.D.module.model.fadein_block
                self.flag_flush = True
                self.print_model_structure()

            if floor(self.resl) >= self.max_resl and self.resl%1.0 >= self.trns_tick*delta:
                self.phase = 'final'
                self.resl = self.max_resl+self.trns_tick*delta


    def print_model_structure(self):

        img_size = self.img_size

        # count model parameters
        nparams_g = count_model_params(self.G)
        nparams_d = count_model_params(self.D)

        with open(self.log_dir+'/model_structure_{}x{}.txt'.format(img_size, img_size),'a') as f:
            print('--------------------------------------------------', file=f)
            print('Sequences in Dataset: ', len(self.dataset), file=f)
            print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
            print('Phase: ', self.phase, file=f)
            print('Number of Generator`s model parameters: ', file=f)
            print(nparams_g, file=f)
            print('Number of Discriminator`s model parameters: ', file=f)
            print(nparams_d, file=f)
            print('--------------------------------------------------', file=f)
            print('New Generator structure: ', file=f)
            print(self.G.module, file=f)
            print('--------------------------------------------------', file=f)
            print('New Discriminator structure: ', file=f)
            print(self.D.module, file=f)
            print('--------------------------------------------------', file=f)
            print(' ... models are being updated ... ')
            print(' ... saving updated model strutures to {}'.format(f))


    def renew_everything(self):

        # renew dataloader
        self.img_size = int(pow(2,min(floor(self.resl), self.max_resl)))
        self.batch_size = int(self.batch_size_table[pow(2,min(floor(self.resl), self.max_resl))])
        self.video_loader = video_loader
        self.transform_video = transforms.Compose([transforms.Resize(size=(self.img_size,self.img_size), interpolation=Image.NEAREST), transforms.ToTensor(),])
        self.dataset = VideoFolder(video_root=self.train_data_root, video_ext=self.ext, nframes=self.nframes, loader=self.video_loader, transform=self.transform_video)
        self.dataloader = DataLoader(dataset=self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.nworkers)
        self.epoch_tick = int(ceil(len(self.dataset)/self.batch_size))

        # define tensors
        self.real_label = Variable(torch.FloatTensor(self.batch_size, 1).fill_(1))
        self.fake_label = Variable(torch.FloatTensor(self.batch_size, 1).fill_(0))

        # wrapping autograd Variable.
        self.z = Variable(torch.FloatTensor(self.batch_size, self.nc, self.nframes_in, self.img_size, self.img_size))
        self.x = Variable(torch.FloatTensor(self.batch_size, self.nc, self.nframes, self.img_size, self.img_size))
        self.x_gen = Variable(torch.FloatTensor(self.batch_size, self.nc, self.nframes_pred, self.img_size, self.img_size))
        self.z_x_gen = Variable(torch.FloatTensor(self.batch_size, self.nc, self.nframes, self.img_size, self.img_size))

        # enable cuda
        if self.use_cuda:
            self.z = self.z.cuda()
            self.x = self.x.cuda()
            self.x_gen = self.x_gen.cuda()
            self.z_x_gen = self.z_x_gen.cuda()
            self.real_label = self.real_label.cuda()
            self.fake_label = self.fake_label.cuda()
            torch.cuda.manual_seed(config.random_seed)

        # ship new model to cuda.
        if self.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()

        # optimizer
        betas = (self.config.beta1, self.config.beta2)
        if self.optimizer == 'adam':
            self.opt_g = Adam(filter(lambda p: p.requires_grad, self.G.parameters()), lr=self.lr, betas=betas, weight_decay=0.0)
            self.opt_d = Adam(filter(lambda p: p.requires_grad, self.D.parameters()), lr=self.lr, betas=betas, weight_decay=0.0)


    def feed_interpolated_input(self, x):

        # interpolate input to match network resolution
        if self.phase == 'Gtrns' and floor(self.resl)>2 and floor(self.resl)<=self.max_resl:
            alpha = self.complete/100.0
            transform = transforms.Compose( [   transforms.ToPILImage(),
                                                transforms.Resize(size=int(pow(2,floor(self.resl)-1)), interpolation=0),      # 0: nearest
                                                transforms.Resize(size=int(pow(2,floor(self.resl))), interpolation=0),      # 0: nearest
                                                transforms.ToTensor(),
                                            ] )

            x_low = x.clone().add(1).mul(0.5)
            for i in range(x_low.size(0)):
                for j in range(x_low.size(2)):
                    x_low[i,:,j,:,:] = transform(x_low[i,:,j,:,:]).mul(2).add(-1)
            x = torch.add(x.mul(alpha), x_low.mul(1-alpha))

        if self.use_cuda:
            return x.cuda()
        else:
            return x


    def add_noise(self, x):

        if self.x_add_noise==False:
            return x

        # add noise to variable
        if hasattr(self, '_d_'):
            self._d_ = self._d_ * 0.9 + torch.mean(self.x_gen_label).data[0] * 0.1
        else:
            self._d_ = 0.0
        strength = 0.2 * max(0, self._d_ - 0.5)**2
        z = np.random.randn(*x.size()).astype(np.float32) * strength
        z = Variable(torch.from_numpy(z)).cuda() if self.use_cuda else Variable(torch.from_numpy(z))
        return x + z


    def get_batch(self):

        dataIter = iter(self.dataloader)
        return next(dataIter)


    def train(self):

        # train loop
        for step in range(self.start_resl, self.max_resl+2):

            for iter in tqdm(range(self.iter_start,(self.trns_tick+self.stab_tick)*int(ceil(len(self.dataset)/self.batch_size)))):
		
                self.iter = iter
                self.globalIter = self.globalIter+1
                self.stack = self.stack + self.batch_size
                #self.save_ckpt(self.ckpt_dir)
                if self.stack > ceil(len(self.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = 0

                   # save ckpt
                    if self.epoch%self.config.save_ckpt_every==0:
                        self.save_ckpt(self.ckpt_dir)

                # schedule resolution and update parameters
                self.schedule_resl()

                # zero gradients
                self.G.zero_grad()
                self.D.zero_grad()

                # interpolate discriminator real input
                self.x.data = self.feed_interpolated_input(self.get_batch())

                # if 'x_add_noise' --> input to generator without noise, input to discriminator with noise
                self.z.data = self.x.data[:,:,:self.nframes_in,:,:]
                if self.x_add_noise:
                    self.x = self.add_noise(self.x)
                if self.config.d_cond:
                    self.z_x_gen = self.G(self.z)
                    self.x_gen.data = self.z_x_gen.data[:,:,self.nframes_in:,:,:]
                    self.x_label = self.D(self.x.detach())
                    self.x_gen_label = self.D(self.z_x_gen.detach())
                else:
                    self.x_gen = self.G(self.z)
                    self.z_x_gen.data[:,:,:self.nframes_in,:,:] = self.z.data
                    self.z_x_gen.data[:,:,self.nframes_in:,:,:] = self.x_gen.data
                    self.x_label = self.D(self.x[:,:,self.nframes_in:,:,:].detach())
                    self.x_gen_label = self.D(self.x_gen.detach())

                # mse loss
                if self.config.loss=='lsgan':
                    loss_d = self.criterion(self.x_label, self.real_label) + self.criterion(self.x_gen_label, self.fake_label)

                # cross entropy with logits loss
                elif self.config.loss=='gan':
                    loss_d = self.criterion(self.x_label, self.real_label) + self.criterion(self.x_gen_label, self.fake_label)

                # wgan-gp loss
                elif self.config.loss=='wgan_gp':
                    loss_d = torch.mean(self.x_gen_label)-torch.mean(self.x_label)

                    # gradient penalty
                    lam = 10
                    alpha = torch.rand(self.batch_size, 1)
                    if self.config.d_cond==False:
                        alpha = alpha.expand(self.batch_size, self.x[:,:,self.nframes_in:,:,:][0].nelement()).contiguous().view(self.batch_size, self.x.size(1), self.x[:,:,self.nframes_in:,:,:].size(2), self.x.size(3), self.x.size(4))
                    else:
                        alpha = alpha.expand(self.batch_size, self.x[0].nelement()).contiguous().view(self.batch_size, self.x.size(1), self.x.size(2), self.x.size(3), self.x.size(4))
                    if self.use_cuda:
                        alpha = alpha.cuda()
                    if self.config.d_cond:
                        interpolates = alpha*self.x.data+((1-alpha)*self.z_x_gen.data)
                    else:
                        interpolates = alpha*self.x[:,:,self.nframes_in:,:,:].data+((1-alpha)*self.x_gen.data)
                    if self.use_cuda:
                        interpolates = interpolates.cuda()
                    interpolates = Variable(interpolates, requires_grad=True)
                    interpolates_label = self.D(interpolates)
                    gradients = torch.autograd.grad(outputs=interpolates_label, inputs=interpolates,
                              grad_outputs=torch.ones(interpolates_label.size()).cuda() if self.use_cuda else torch.ones(interpolates_label.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
#                    gradients = torch.autograd.grad(outputs=interpolates_label.sum().cuda() if self.use_cuda else interpolates_label.sum(), inputs=interpolates, create_graph=True)[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_penalty = ((gradients.norm(2, dim=1)-1)**2).mean()
                    loss_d = loss_d+lam*gradient_penalty

                    # epsilon penalty
                    if self.d_eps_penalty==True:
                        eps = 0.001
                        eps_penalty = torch.mean((self.x_label-0)**2)
                        loss_d = loss_d+eps_penalty*eps

                    # label penalty
                    # !!! makes GAN conditioned on classification labels of dataset
                    # only makes sense, if actual labels are given, which is not the case here.
                    if self.acgan==True:
                        cond_weight_d = 1.0
                        label_penalty_d = self.criterion(self.x_gen_label, self.fake_label)+self.criterion(self.x_label, self.real_label)
                        loss_d = loss_d+label_penalty_d*cond_weight_d

                # update discriminator
                loss_d.backward()
                self.opt_d.step()

                # get discriminator output
                if self.config.d_cond:
                    self.x_gen_label = self.D(self.z_x_gen)
                else:
                    self.x_gen_label = self.D(self.x_gen)

                # mse loss
                if self.config.loss=='lsgan':
                    loss_g = self.criterion(self.x_gen_label, self.real_label.detach())

                # cross entropy with logits loss
                elif self.config.loss=='gan':
                    loss_g = self.criterion(self.x_gen_label, self.real_label.detach())

                # wgan loss
                elif self.config.loss=='wgan_gp':
                    loss_g = -torch.mean(self.x_gen_label)

                    # label penalty
                    if self.acgan==True:
                        cond_weight_g = 1.0
                        label_penalty_g = self.criterion(self.x_gen_label, self.fake_label)
                        loss_g = loss_g+label_penalty_g*cond_weight_g

                # update generator
                loss_g.backward()
                self.opt_g.step()

                # set max. nr of samples for saving video grid logs
                if self.batch_size >= 8:
                    k = 8
                else:
                    k = self.batch_size

                # save video grid logs
                if self.globalIter%self.config.save_img_every==0 or self.globalIter==1:

                    # log x, z_x_gen
                    if not os.path.exists(self.train_sample_dir):
                        os.makedirs(self.train_sample_dir)

                    # save video grid: x, z_x_gen images
                    save_video_grid(self.x.data[:k,:,:,:,:], self.train_sample_dir+'/'+'x_E{}_I{}_R{}x{}_{}_G{}_D{}.jpg'.format(int(self.epoch), int(self.globalIter), int(self.img_size), int(self.img_size), self.phase, self.complete, self.complete))
                    save_video_grid(self.z_x_gen.data[:k,:,:,:,:], self.train_sample_dir+'/'+'z_x_gen_E{}_I{}_R{}x{}_{}_G{}_D{}.jpg'.format(int(self.epoch), int(self.globalIter), int(self.img_size), int(self.img_size), self.phase, self.complete, self.complete))

                # save tensorboard logs
                if self.tb_logging==True:

                    if self.globalIter%self.config.update_tb_every==0 or self.globalIter==1:

                        # log loss_g and loss_d
                        self.logger.log_scalar('loss/G', loss_g.item(), self.globalIter)
                        self.logger.log_scalar('loss/D', loss_d.item(), self.globalIter)

                        # log resl, lr and epoch
                        self.logger.log_scalar('tick/resl', int(pow(2,floor(self.resl))), self.globalIter)
                        self.logger.log_scalar('tick/lr', self.lr, self.globalIter)
                        self.logger.log_scalar('tick/epoch', self.epoch, self.globalIter)

                    # log model parameter histograms weight, bias, weight.grad and bias.grad
                    if self.globalIter%(self.config.update_tb_every*10)==0 or self.globalIter==1:
                        for tag, value in self.G.named_parameters():
                            tag = tag.replace('.', '/')
                            self.logger.log_histogram('G/'+tag, self.var2np(value), self.globalIter)
                            if value.grad is not None:
                                self.logger.log_histogram('G/'+tag+'/grad', self.var2np(value.grad), self.globalIter)
                        for tag, value in self.D.named_parameters():
                            tag = tag.replace('.', '/')
                            self.logger.log_histogram('D/'+tag, self.var2np(value), self.globalIter)
                            if value.grad is not None:
                                self.logger.log_histogram('D/'+tag+'/grad', self.var2np(value.grad), self.globalIter)

                # if iter>2:
                #     break
                    
            self.iter_start=0

            # save final model
        self.save_final_model(self.log_dir)


    def get_state(self, target):

        # ship models to cpu
        self.G_save = self.G.cpu()
        self.D_save = self.D.cpu()

        if target == 'G':
            state = {
                'G_structure': self.G_save,
                'globalIter': self.globalIter,
                'nsamples': self.nsamples,
                'stack': self.stack,
                'epoch': self.epoch,
                'resl' : self.resl,
                'iter': self.iter,
                'state_dict' : self.G_save.state_dict(),
                'optimizer' : self.opt_g.state_dict(),
                'fadein'    : self.fadein['G'],
                'phase'     : self.phase,
                'complete': self.complete,
                'flag_flush': self.flag_flush,
            }
            return state

        elif target == 'D':
            state = {
                'D_structure': self.D_save,
                'globalIter': self.globalIter,
                'nsamples': self.nsamples,
                'stack': self.stack,
                'epoch': self.epoch,
                'resl' : self.resl,
                'iter': self.iter,
                'state_dict' : self.D_save.state_dict(),
                'optimizer' : self.opt_d.state_dict(),
                'fadein'    : self.fadein['D'],
                'phase'     : self.phase,
                'complete': self.complete,
                'flag_flush': self.flag_flush,
            }
            return state


    def save_ckpt(self, path):

        if not os.path.exists(path):
            os.makedirs(path)
        ndis = 'dis_E{}_I{}_R{}x{}_{}.pth.tar'.format(self.epoch, self.globalIter, self.img_size, self.img_size, self.phase)
        ngen = 'gen_E{}_I{}_R{}x{}_{}.pth.tar'.format(self.epoch, self.globalIter, self.img_size, self.img_size, self.phase)
        save_path = os.path.join(path, ndis)
        if not os.path.exists(save_path):
            # ship models to cpu in get_state and save models
            torch.save(self.get_state('D'), save_path)
            save_path = os.path.join(path, ngen)
            torch.save(self.get_state('G'), save_path)
            print(' ... saving model checkpoints to {}'.format(path))

        # re-ship everything to cuda
        if self.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()


    def get_final_state(self, target):

        # ship models to cpu
        self.G_save = self.G.cpu()
        self.D_save = self.D.cpu()

        if target == 'G':
            state = {
                'G_structure': self.G_save,
                'resl' : self.resl,
                'state_dict' : self.G_save.state_dict(),
            }
            return state

        elif target == 'D':
            state = {
                'D_structure': self.D_save,
                'resl' : self.resl,
                'state_dict' : self.D_save.state_dict(),
            }
            return state


    def save_final_model(self, path):

        if not os.path.exists(path):
            os.makedirs(path)
        ndis = 'dis_E{}_I{}_R{}x{}_final.pth.tar'.format(self.epoch, self.globalIter, self.img_size, self.img_size)
        ngen = 'gen_E{}_I{}_R{}x{}_final.pth.tar'.format(self.epoch, self.globalIter, self.img_size, self.img_size)
        save_path = os.path.join(path, ndis)
        if not os.path.exists(save_path):
            # ship models to cpu in get_state and save models
            torch.save(self.get_final_state('D'), save_path)
            save_path = os.path.join(path, ngen)
            torch.save(self.get_final_state('G'), save_path)
            print(' ... saving final models to {}'.format(path))


    def var2np(self, var):
        if self.use_cuda:
            return var.cpu().data.numpy()
        return var.data.numpy()