# --- Parameter setup -------------------------------------------------------
# Allocate the model's parameters, then wrap them in Theano shared variables.
log.info("Initializing parameters")
model.init_params()

log.info('Creating shared variables')
model.init_shared_variables()

# Names of weights that are excluded from back-propagation updates. It stays
# empty unless pre-trained weights are loaded while `freeze` is set.
dont_update = []

# Optionally replace the freshly initialised weights with pre-trained ones.
if train_args.init:
    init_file = train_args.init
    log.info('Will override parameters from pre-trained weights')
    log.info(' %s' % os.path.basename(init_file))
    new_params = get_param_dict(init_file)
    model.update_shared_variables(new_params)
    # NOTE(review): `freeze` is defined outside this chunk; presumably a
    # CLI/config flag. Confirm against the enclosing scope.
    if freeze:
        log.info('Pretrained weights will not be updated.')
        dont_update = list(new_params.keys())

# --- Reporting and data loading --------------------------------------------
log.info("Number of parameters: %s" % model.get_nb_params())

log.info("Loading data")
model.load_data()

# Print a summary of the model configuration.
model.info()
# Wrap the generator's parameters in Theano shared variables.
model.init_shared_variables()
# Khoa: the GAN setup also needs shared variables for the discriminator and,
# when a language-model type is configured, for the language model.
discriminator.init_shared_variables()
if train_args.model_language_model_type is not None:
    language_model.init_shared_variables()

# Names of weights that will not receive updates during BP. With `freeze`
# set, both the generator checkpoint and the discriminator checkpoint
# contribute names here; neither overwrites the other's entries.
dont_update = []

# Override generator weights with pre-trained ones if given.
if train_args.init:
    log.info(
        'Will override parameters from pre-trained weights init Generator')
    log.info(' %s' % os.path.basename(train_args.init))
    new_params = get_param_dict(train_args.init)
    model.update_shared_variables(new_params)
    # NOTE(review): `freeze` is defined outside this chunk; confirm its source.
    if freeze:
        log.info('Pretrained weights will not be updated.')
        dont_update.extend(new_params.keys())

# Override discriminator weights with pre-trained ones if given.
if train_args.initdis:
    log.info(
        'Will override parameters from pre-trained weights init Discriminator'
    )
    log.info(' %s' % os.path.basename(train_args.initdis))
    new_params = get_param_dict(train_args.initdis)
    discriminator.update_shared_variables(new_params)
    if freeze:
        log.info('Pretrained weights will not be updated.')
        # Extend rather than reassign. Reassigning here discarded the
        # generator's frozen names whenever both checkpoints were given.
        # Names already listed are skipped so each appears only once.
        new_frozen = [name for name in new_params.keys()
                      if name not in dont_update]
        dont_update.extend(new_frozen)