Example #1
    # Initialize parameters
    log.info("Initializing parameters")
    model.init_params()

    # Create theano shared variables
    log.info('Creating shared variables')
    model.init_shared_variables()

    # List of weights that will not receive updates during BP
    dont_update = []

    # Override some weights with pre-trained ones if given
    if train_args.init:
        log.info('Will override parameters from pre-trained weights')
        log.info('  %s' % os.path.basename(train_args.init))
        new_params = get_param_dict(train_args.init)
        model.update_shared_variables(new_params)
        if freeze:
            log.info('Pretrained weights will not be updated.')
            dont_update = list(new_params.keys())

    # Print number of parameters
    log.info("Number of parameters: %s" % model.get_nb_params())

    # Load data
    log.info("Loading data")
    model.load_data()

    # Dump model information
    model.info()
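Downstream, the training script typically uses dont_update to skip the frozen weights when computing gradient updates. The following is a minimal sketch of that filtering step; select_trainable is a hypothetical helper written only for illustration, not part of the toolkit's API.

    # Minimal sketch, assuming parameters are kept in a {name: shared variable}
    # dict. select_trainable() is hypothetical and only illustrates how the
    # dont_update list would exclude frozen weights from backpropagation.
    def select_trainable(tparams, dont_update):
        """Return only the parameters that should receive gradient updates."""
        return {name: var for name, var in tparams.items()
                if name not in dont_update}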
Example #2
    model.init_shared_variables()
    # Khoa: also create shared variables for the discriminator and, when a
    # language model type is configured, for the language model.
    discriminator.init_shared_variables()
    if train_args.model_language_model_type is not None:
        language_model.init_shared_variables()
    # Khoa.

    # List of weights that will not receive updates during BP
    dont_update = []

    # Override some weights with pre-trained ones if given
    if train_args.init:
        log.info('Will override generator parameters from pre-trained weights')
        log.info('  %s' % os.path.basename(train_args.init))
        new_params = get_param_dict(train_args.init)
        model.update_shared_variables(new_params)
        if freeze:
            log.info('Pretrained weights will not be updated.')
            dont_update = list(new_params.keys())

    if train_args.initdis:
        log.info('Will override discriminator parameters from pre-trained weights')
        log.info('  %s' % os.path.basename(train_args.initdis))
        new_params = get_param_dict(train_args.initdis)
        discriminator.update_shared_variables(new_params)
        if freeze:
            log.info('Pretrained weights will not be updated.')
            # Extend rather than overwrite, so generator keys stay frozen too.
            dont_update += list(new_params.keys())
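For reference, a get_param_dict-style checkpoint loader is often a thin wrapper around numpy.load. The sketch below assumes the pre-trained weights are stored as a NumPy .npz archive mapping parameter names to arrays; it is illustrative only and may not match the toolkit's actual loader.

    import numpy as np

    def load_param_dict(path):
        """Hypothetical loader: return a {name: ndarray} dict from an .npz file."""
        with np.load(path) as checkpoint:
            return {name: checkpoint[name] for name in checkpoint.files}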