# Example 1
def main(_):
    """Train the VS_CNN base model on FLAGS.dataset.

    Resets the log/checkpoint directories, builds the data pipeline and
    model, then runs FLAGS.num_epochs of training: the first
    ``warmup_epochs`` epochs update only the freshly-initialized layers
    (``warm_up`` op), after which the pre-trained layers are fine-tuned
    as well (``train_op``). Per-epoch test results are appended to
    ``result_<dataset>_base.txt``.
    """
    skip_layers = ['fc8']  # no pre-trained weights
    train_layers = ['fc8']
    finetune_layers = [
        'fc7', 'fc6', 'conv5', 'conv4', 'conv3', 'conv2', 'conv1'
    ]
    # Number of epochs that train only the new layers before unfreezing
    # the pre-trained ones (was a hard-coded `epoch < 20` check).
    warmup_epochs = 20

    writer_dir = "logs/{}".format(FLAGS.dataset)
    checkpoint_dir = "checkpoints/{}".format(FLAGS.dataset)

    # Start every run from clean log/checkpoint directories.
    for directory in (writer_dir, checkpoint_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    generator = DataGenerator(data_dir=FLAGS.data_dir,
                              dataset=FLAGS.dataset,
                              batch_size=FLAGS.batch_size,
                              num_threads=FLAGS.num_threads)

    model = VS_CNN(num_classes=2, skip_layers=skip_layers)
    loss = loss_fn(model)
    warm_up, train_op, learning_rate = train_fn(loss, generator,
                                                finetune_layers, train_layers)
    saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints)

    config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        logger = Logger(writer_dir, sess.graph)

        sess.run(tf.global_variables_initializer())
        model.load_initial_weights(sess)

        print("{} Start training...".format(datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.now(), writer_dir))
        # Context manager guarantees the result file is closed even if
        # train()/test() raises (the original leaked the handle on error).
        with open('result_{}_base.txt'.format(FLAGS.dataset), 'w') as result_file:
            for epoch in range(FLAGS.num_epochs):
                print("\n{} Epoch: {}/{}".format(datetime.now(), epoch + 1,
                                                 FLAGS.num_epochs))
                result_file.write("\n{} Epoch: {}/{}\n".format(
                    datetime.now(), epoch + 1, FLAGS.num_epochs))

                # Warm up the new layers first, then fine-tune everything.
                update_op = warm_up if epoch < warmup_epochs else train_op
                train(sess, model, generator, update_op, learning_rate, loss,
                      epoch, logger)

                test(sess, model, generator, result_file)
                # save_model(sess, saver, epoch, checkpoint_dir)
    def __init__(self, config):
        """Initialize the GAN trainer from *config*.

        Sets up log/sample/checkpoint directories, dumps the config to
        ``train_config.txt``, builds generator ``self.G`` and
        discriminator ``self.D``, selects the adversarial criterion from
        ``config.loss``, and then either creates fresh models
        (``config.use_ckpt == False``) or resumes full training state
        from the two checkpoint files in ``config.ckpt_path``.

        NOTE(review): relies on module-level names not visible in this
        chunk (``current_time``, ``use_cuda``, ``model``,
        ``count_model_params``, ``floor``) — assumed to be defined at
        import time in this file.
        """

        self.config = config

        # log directory: fall back to a timestamp when no explicit name is given
        if self.config.experiment_name=='':
            self.experiment_name = current_time
        else:
            self.experiment_name = self.config.experiment_name

        self.log_dir = config.log_dir+'/'+self.experiment_name
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # save config settings to file (one "key: value" line per attribute)
        with open(self.log_dir+'/train_config.txt','w') as f:
            print('------------- training configuration -------------', file=f)
            for k, v in vars(config).items():
                print(('{}: {}').format(k, v), file=f)
            # These two prints go to stdout, not the file (no file= arg).
            print(' ... loading training configuration ... ')
            # NOTE(review): this formats the file OBJECT (repr), not its
            # path — presumably f.name was intended; confirm before changing.
            print(' ... saving training configuration to {}'.format(f))

        self.train_data_root = self.config.data_root

        # training samples
        self.train_sample_dir = self.log_dir+'/samples_train'

        # checkpoints
        self.ckpt_dir = self.log_dir+'/ckpts'
        print('#######################')

        print("checkpoint dir:",self.ckpt_dir)
        print('#######################')

        # tensorboard
        # NOTE(review): self.tb_dir is only defined when config.tb_logging
        # is truthy; the logger setup at the bottom re-checks the same flag,
        # so the attribute is never read undefined.
        if self.config.tb_logging:
            self.tb_dir = self.log_dir+'/tensorboard'

        self.use_cuda = use_cuda
        self.nz = config.nz                          # latent vector size
        self.nc = config.nc                          # image channels
        self.optimizer = config.optimizer
        self.batch_size_table = config.batch_size_table
        self.lr = config.lr
        self.d_eps_penalty = config.d_eps_penalty
        self.acgan = config.acgan
        # max_resl is stored as log2 of the configured maximum resolution
        self.max_resl = int(np.log2(config.max_resl))
        self.nframes_in = config.nframes_in          # conditioning frames
        self.nframes_pred = config.nframes_pred      # frames to predict
        self.nframes = self.nframes_in+self.nframes_pred
        self.ext = config.ext
        self.nworkers = 4                            # dataloader workers (hard-coded)
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.complete = 0.0
        self.x_add_noise = config.x_add_noise
        self.fadein = {'G':None, 'D':None}           # fade-in state per network
        self.init_resl = 2                           # start at 2**2 = 4x4 images
        self.init_img_size = int(pow(2, self.init_resl))

        # initialize model G as FutureGenerator from model.py
        self.G = model.FutureGenerator(config)

        # initialize model D as Discriminator from model.py
        self.D = model.Discriminator(config)

        # define losses: lsgan -> MSE; gan/wgan_gp -> BCE, with or without
        # a sigmoid already applied by D (d_sigmoid decides which BCE variant)
        if self.config.loss=='lsgan':
            self.criterion = torch.nn.MSELoss()

        elif self.config.loss=='gan':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()

        elif self.config.loss=='wgan_gp':
            if self.config.d_sigmoid==True:
                self.criterion = torch.nn.BCELoss()
            else:
                self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Loss is undefined! Please set one of the following: `gan`, `lsgan` or `wgan_gp`')

        # check --use_ckpt
        # if --use_ckpt==False: build initial model
        # if --use_ckpt==True: load and build model from specified checkpoints
        if self.config.use_ckpt==False:

            print(' ... creating initial models ... ')

            # set initial model parameters
            self.resl = self.init_resl
            self.start_resl = self.init_resl
            self.globalIter = 0
            self.nsamples = 0
            self.stack = 0
            self.epoch = 0
            self.iter_start = 0
            self.phase = 'init'
            self.flag_flush = False

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()

            # count model parameters
            nparams_g = count_model_params(self.G)
            nparams_d = count_model_params(self.D)

            # save initial model structure to file
            with open(self.log_dir+'/initial_model_structure_{}x{}.txt'.format(self.init_img_size, self.init_img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), ', Batch size: ', self.batch_size, file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('Number of Generator`s model parameters: ', file=f)
                print(nparams_g, file=f)
                print('Number of Discriminator`s model parameters: ', file=f)
                print(nparams_d, file=f)
                print('--------------------------------------------------', file=f)
                print('Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                # stdout only (no file=); see NOTE(review) above about
                # formatting the file object instead of its path.
                print(' ... initial models have been built successfully ... ')
                print(' ... saving initial model strutures to {}'.format(f))

            # ship everything to cuda and parallelize for ngpu>1
            # NOTE(review): DataParallel is applied even for ngpu==1.
            if self.use_cuda:
                self.criterion = self.criterion.cuda()
                torch.cuda.manual_seed(config.random_seed)
                if config.ngpu==1:
                    self.G = torch.nn.DataParallel(self.G).cuda(device=0)
                    self.D = torch.nn.DataParallel(self.D).cuda(device=0)
                else:
                    gpus = []
                    for i  in range(config.ngpu):
                        gpus.append(i)
                    self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda()
                    self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda()

        else:

            # re-ship everything to cuda
            # NOTE(review): these .cuda() calls act on the freshly-built
            # models, which are then REPLACED below by the pickled
            # 'G_structure'/'D_structure' from the checkpoints.
            if self.use_cuda:
                self.G = self.G.cuda()
                self.D = self.D.cuda()

            # load checkpoint (index 0: generator, index 1: discriminator)
            print(' ... loading models from checkpoints ... {} and {}'.format(self.config.ckpt_path[0], self.config.ckpt_path[1]))
            self.ckpt_g = torch.load(self.config.ckpt_path[0])
            self.ckpt_d = torch.load(self.config.ckpt_path[1])

            # get model parameters (resume exact training-schedule state)
            self.resl = self.ckpt_g['resl']
            self.start_resl = int(self.ckpt_g['resl'])
            self.iter_start = self.ckpt_g['iter']+1
            self.globalIter = int(self.ckpt_g['globalIter'])
            self.stack = int(self.ckpt_g['stack'])
            self.nsamples = int(self.ckpt_g['nsamples'])
            self.epoch = int(self.ckpt_g['epoch'])
            self.fadein['G'] = self.ckpt_g['fadein']
            self.fadein['D'] = self.ckpt_d['fadein']
            self.phase = self.ckpt_d['phase']
            self.complete  = self.ckpt_d['complete']
            self.flag_flush = self.ckpt_d['flag_flush']
            img_size = int(pow(2, floor(self.resl)))

            # get model structure (whole pickled modules, not just weights)
            self.G = self.ckpt_g['G_structure']
            self.D = self.ckpt_d['D_structure']

            # define tensors, ship model to cuda, and get dataloader
            self.renew_everything()
            self.schedule_resl()
            # nsamples is restored again because schedule_resl() may have
            # advanced it — keep the checkpointed value.
            self.nsamples = int(self.ckpt_g['nsamples'])

            # save loaded model structure to file
            with open(self.log_dir+'/resumed_model_structure_{}x{}.txt'.format(img_size, img_size),'w') as f:
                print('--------------------------------------------------', file=f)
                print('Sequences in Dataset: ', len(self.dataset), file=f)
                print('Global iteration step: ', self.globalIter, ', Epoch: ', self.epoch, file=f)
                print('Phase: ', self.phase, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Generator structure: ', file=f)
                print(self.G, file=f)
                print('--------------------------------------------------', file=f)
                print('Reloaded Discriminator structure: ', file=f)
                print(self.D, file=f)
                print('--------------------------------------------------', file=f)
                print(' ... models have been loaded successfully from checkpoints ... ')
                print(' ... saving resumed model strutures to {}'.format(f))

            # load model state_dict (must come after the structure swap above)
            self.G.load_state_dict(self.ckpt_g['state_dict'])
            self.D.load_state_dict(self.ckpt_d['state_dict'])

            # load optimizer state dict
            # NOTE(review): the loop only matters for its last iteration —
            # it re-derives the decayed lr for the resumed resolution level;
            # the optimizer state_dict restore itself is commented out.
            lr = self.lr
            for i in range(1,int(floor(self.resl))-1):
                self.lr = lr*(self.config.lr_decay**i)
#            self.opt_g.load_state_dict(self.ckpt_g['optimizer'])
#            self.opt_d.load_state_dict(self.ckpt_d['optimizer'])
#            for param_group in self.opt_g.param_groups:
#                self.lr = param_group['lr']

        # tensorboard logging
        self.tb_logging = self.config.tb_logging
        if self.tb_logging==True:
            if not os.path.exists(self.tb_dir):
                os.makedirs(self.tb_dir)
            self.logger = Logger(self.tb_dir)