Example 1
    train_pairs, test_pairs = get_engfra_split(
        split=experiment_arguments.split)

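    # New primitive pair (English word, French translation) introduced by the chosen split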
    if experiment_arguments.split == "add_book":
        new_prim_pair = ['book', 'livre']
    elif experiment_arguments.split == "add_house":
        new_prim_pair = ['house', 'maison']

    in_equivariances, out_equivariances = get_equivariances(
        experiment_arguments.equivariance)
    equivariant_eng, equivariant_fra = get_equivariant_engfra_languages(
        pairs=train_pairs + test_pairs,
        input_equivariances=in_equivariances,
        output_equivariances=out_equivariances)

    input_symmetry_group = get_permutation_equivariance(equivariant_eng)
    output_symmetry_group = get_permutation_equivariance(equivariant_fra)

    # Initialize model
    model = EquiSeq2Seq(input_symmetry_group=input_symmetry_group,
                        output_symmetry_group=output_symmetry_group,
                        input_language=equivariant_eng,
                        encoder_hidden_size=experiment_arguments.hidden_size,
                        decoder_hidden_size=experiment_arguments.hidden_size,
                        output_language=equivariant_fra,
                        layer_type=experiment_arguments.layer_type,
                        use_attention=experiment_arguments.use_attention,
                        bidirectional=experiment_arguments.bidirectional)

    # Move model to device and load weights
    model.to(device)
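The snippet breaks off before the weight loading that the final comment refers to. A minimal sketch of that step in plain PyTorch (with `torch` assumed to be imported in the original script), using a hypothetical checkpoint path `model_path` that holds a saved `state_dict` (both the name and the checkpoint format are assumptions, not taken from the example):

    # Hypothetical continuation: restore saved weights and switch to evaluation mode
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()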
Example 2
    equivariant_commands, equivariant_actions = get_equivariant_scan_languages(
        pairs=train_pairs,
        input_equivariances=in_equivariances,
        output_equivariances=out_equivariances)
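    # The 'verb+direction' setting uses the dedicated SCAN symmetry group below;
    # any other equivariance falls back to a generic permutation group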
    if experiment_arguments.equivariance == 'verb+direction':
        from perm_equivariant_seq2seq.symmetry_groups import VerbDirectionSCAN

        input_symmetry_group = VerbDirectionSCAN(
            num_letters=equivariant_commands.n_words,
            first_equivariant=equivariant_commands.num_fixed_words + 1
        )
        output_symmetry_group = VerbDirectionSCAN(
            num_letters=equivariant_actions.n_words,
            first_equivariant=equivariant_actions.num_fixed_words + 1
        )
    else:
        input_symmetry_group = get_permutation_equivariance(equivariant_commands)
        output_symmetry_group = get_permutation_equivariance(equivariant_actions)

    # Initialize model
    model = EquiSeq2Seq(input_symmetry_group=input_symmetry_group,
                        output_symmetry_group=output_symmetry_group,
                        input_language=equivariant_commands,
                        encoder_hidden_size=experiment_arguments.hidden_size,
                        decoder_hidden_size=experiment_arguments.hidden_size,
                        output_language=equivariant_actions,
                        layer_type=experiment_arguments.layer_type,
                        use_attention=experiment_arguments.use_attention,
                        bidirectional=experiment_arguments.bidirectional)

    # Move model to device and load weights
    model.to(device)
Example 3
    return loss.item()


if __name__ == '__main__':
    # Load data
    train_pairs, test_pairs = get_engfra_split(split=args.split)
    print("Got training and testing pairs")

    in_equivariances, out_equivariances = get_equivariances(args.equivariance)
    eng_lang, fra_lang = get_equivariant_engfra_languages(
        pairs=train_pairs + test_pairs,
        input_equivariances=in_equivariances,
        output_equivariances=out_equivariances)
    print("making symmetry group")
    input_symmetry_group = get_permutation_equivariance(eng_lang)
    output_symmetry_group = get_permutation_equivariance(fra_lang)
    print("symmetry groups made")
    # Initialize model
    model = EquiSeq2Seq(input_symmetry_group=input_symmetry_group,
                        output_symmetry_group=output_symmetry_group,
                        input_language=eng_lang,
                        encoder_hidden_size=args.hidden_size,
                        decoder_hidden_size=args.hidden_size,
                        output_language=fra_lang,
                        layer_type=args.layer_type,
                        use_attention=args.use_attention,
                        bidirectional=args.bidirectional)
    print("model made")
    model.to(device)
    print("model to device")