Exemplo n.º 1
0
def get_data(opts):
    """
    Build every ordered pair for the configured number of values and split it.

    The splits, as described by the original author (the exact behaviour of
    `split_holdout` / `split_train_test` is defined outside this function):

    generalization_holdout ... pairs containing a zero, minus (0, 1) and (1, 0),
                               which are removed here; (0, 0) is not expected
                               to be part of this split in the first place
    uniform_holdout ... the part of the zero-free pairs that
                        `split_train_test(train, 0.1)` sets aside
    train ... the remaining zero-free pairs, prefixed with the three pairs
              (0, 0), (0, 1), (1, 0)

    Only the two-attribute setting is supported (checked with an assert).

    Returns the tuple
    (full_data, train, uniform_holdout, generalization_holdout).
    """
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    zero_free_pairs, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(zero_free_pairs, 0.1)
    assert opts.n_attributes == 2

    # Pairs with a zero that the training set is still allowed to see.
    seen_zero_pairs = [(0, 0), (0, 1), (1, 0)]
    train = seen_zero_pairs + train

    # (0, 0) is skipped on purpose: it is not in generalization_holdout.
    for seen_pair in seen_zero_pairs[1:]:
        generalization_holdout.remove(seen_pair)

    return full_data, train, uniform_holdout, generalization_holdout
Exemplo n.º 2
0
def main(params):
    """Train a Sender/Receiver game, then probe the frozen Sender with new Receivers.

    Steps, in order:

    1. Parse options with `get_params` and enumerate the attribute-value
       inputs; when `opts.density_data > 0` only the subset returned by
       `select_subset_V2` is kept.
    2. Split the data with `split_holdout` and `split_train_test(train, 0.1)`,
       one-hot encode each split and wrap it in `ScaledDataset` / `DataLoader`.
       The validation set is built from the training examples with a scaling
       factor of 1 and is served as a single batch.
    3. Build the receiver, the sender and the game, and train with an
       accuracy-based early stopper plus metric / hold-out evaluators.
    4. If the mean validation accuracy of the last recorded epoch exceeds
       0.99, deep-copy and freeze the sender, then retrain fresh receivers
       (four architectures, seeds 17-19). For every run one JSON line is
       printed whose "auc" is the sum of the per-epoch validation accuracies,
       padded with 1.0 up to `opts.n_epochs // 2` entries.

    Args:
        params: forwarded unchanged to `get_params` (presumably the command
            line arguments -- confirm against `get_params`).

    Raises:
        ValueError: if `opts.receiver_cell` or `opts.sender_cell` is not one
            of "lstm", "rnn", "gru".
        KeyError: if `opts.baseline` is not one of "no", "mean", "builtin".
    """
    import copy

    opts = get_params(params)

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout]
    ]

    # Both datasets wrap the same one-hot training examples; only the scaling
    # argument differs (opts.data_scaler for training, 1 for validation).
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)

    generalization_holdout = ScaledDataset(generalization_holdout)
    uniform_holdout = ScaledDataset(uniform_holdout)
    generalization_holdout_loader, uniform_holdout_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # One batch that covers the whole validation set.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values
    supported_cells = ["lstm", "rnn", "gru"]

    if opts.receiver_cell not in supported_cells:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")
    receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
    # The receiver vocabulary is one larger than the sender's; the sender is
    # wrapped in PlusOneWrapper below (presumably shifting its symbols by one
    # -- confirm against PlusOneWrapper).
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )

    if opts.sender_cell not in supported_cells:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")
    sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(
        agent=sender,
        vocab_size=opts.vocab_size,
        embed_dim=opts.sender_emb,
        hidden_size=opts.sender_hidden,
        max_len=opts.max_len,
        cell=opts.sender_cell,
    )

    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    # (name, loader, loss) triples evaluated by the hold-out Evaluator.
    loaders = [
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        ),
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        ),
    ]

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    # Train new agents
    if validation_acc > 0.99:
        retrain_epochs = opts.n_epochs // 2

        def _set_seed(seed):
            """Seed Python, NumPy and torch (plus CUDA when available)."""
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against `sender`; return per-epoch accuracies.

            Only the receiver's parameters are handed to the optimizer. The
            result is a list of plain floats, one per recorded validation
            epoch.
            """
            new_receiver = receiver_generator()
            probe_game = core.SenderReceiverRnnReinforce(
                sender,
                new_receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            probe_optimizer = torch.optim.Adam(new_receiver.parameters(), lr=opts.lr)
            probe_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            probe_trainer = core.Trainer(
                game=probe_game,
                optimizer=probe_optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[probe_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            probe_trainer.train(n_epochs=retrain_epochs)

            # Read each entry exactly like the main run above does
            # (entry[1].aux["acc"] averaged) instead of subscripting the
            # entry directly; float() yields plain numbers so that sum() and
            # json.dumps() below work.
            return [
                float(entry[1].aux["acc"].mean())
                for entry in probe_stopper.validation_stats
            ]

        frozen_sender = Freezer(copy.deepcopy(sender))

        def make_gru_receiver_generator(n_hidden):
            """Return a zero-argument factory for a GRU receiver of that size."""

            def generate():
                return core.RnnReceiverDeterministic(
                    Receiver(n_hidden=n_hidden, n_outputs=n_dim),
                    opts.vocab_size + 1,
                    opts.receiver_emb,
                    hidden_size=n_hidden,
                    cell="gru",
                )

            return generate

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        receiver_generators = [
            ("gru", make_gru_receiver_generator(opts.receiver_hidden)),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", make_gru_receiver_generator(50)),
            ("small_gru", make_gru_receiver_generator(100)),
        ]

        for name, receiver_generator in receiver_generators:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Fewer entries than retrain_epochs means training stopped
                # early; the missing epochs are counted as accuracy 1.0.
                accs += [1.0] * (retrain_epochs - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()
Exemplo n.º 3
0
def main(params):
    """Train the two-channel ("MM") Sender/Receiver game and probe the Sender.

    Flow:

    1. Options come from `get_params`; the attribute-value inputs are
       enumerated and, when `opts.density_data > 0`, replaced by the subset
       returned by `select_subset_V2`.
    2. The data is split with `split_holdout` / `split_train_test(train, 0.1)`,
       one-hot encoded and wrapped in `ScaledDataset` / `DataLoader`. The
       validation set reuses the training examples with a scaling factor of 1
       and is served as one batch.
    3. One `MMSender` is shared by two `SplitWrapper` views (indices 0 and 1);
       each view drives its own `RnnSenderReinforce`, and both are merged by
       `CombineMMRnnSenderReinforce`. The game is trained with early stopping.
    4. If the last recorded validation accuracy is above 0.99, a frozen deep
       copy of the sender is used to retrain fresh receivers (four
       architectures, seeds 17-19). Each run prints one JSON line whose "auc"
       is the sum of the per-epoch validation accuracies, padded with 1.0 up
       to `opts.n_epochs // 2` entries.

    Args:
        params: forwarded unchanged to `get_params` (presumably the command
            line arguments -- confirm against `get_params`).

    Raises:
        ValueError: if `opts.receiver_cell` or `opts.sender_cell` is not one
            of 'lstm', 'rnn', 'gru'.
        KeyError: if `opts.baseline` is not one of 'no', 'mean', 'builtin'.
    """
    import copy

    opts = get_params(params)
    device = opts.device

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        full_data = copy.deepcopy(
            select_subset_V2(
                full_data, opts.density_data, opts.n_attributes, opts.n_values
            )
        )

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    def to_one_hot(examples):
        return one_hotify(examples, opts.n_attributes, opts.n_values)

    generalization_holdout = to_one_hot(generalization_holdout)
    train = to_one_hot(train)
    uniform_holdout = to_one_hot(uniform_holdout)
    full_data = to_one_hot(full_data)

    # Training and validation wrap the same one-hot examples; only the
    # scaling argument differs.
    one_hot_train = train
    train = ScaledDataset(one_hot_train, opts.data_scaler)
    validation = ScaledDataset(one_hot_train, 1)

    generalization_holdout = ScaledDataset(generalization_holdout)
    uniform_holdout = ScaledDataset(uniform_holdout)
    full_data = ScaledDataset(full_data)

    generalization_holdout_loader = DataLoader(
        generalization_holdout, batch_size=opts.batch_size
    )
    uniform_holdout_loader = DataLoader(uniform_holdout, batch_size=opts.batch_size)
    full_data_loader = DataLoader(full_data, batch_size=opts.batch_size)

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # A single batch spanning the whole validation set.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values
    supported_cells = ("lstm", "rnn", "gru")

    if opts.receiver_cell not in supported_cells:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")
    receiver = core.MMRnnReceiverDeterministic(
        MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
        opts.vocab_size + 1,
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )

    if opts.sender_cell not in supported_cells:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    # One underlying agent, exposed through two views.
    shared_agent = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
    first_view = SplitWrapper(shared_agent, 0)
    second_view = SplitWrapper(shared_agent, 1)

    def build_channel_sender(agent_view):
        """Wrap one view in RnnSenderReinforce, then in PlusNWrapper(..., 1)."""
        channel = core.RnnSenderReinforce(
            agent=agent_view,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            force_eos=False,
            cell=opts.sender_cell,
        )
        return PlusNWrapper(channel, 1)

    sender1 = build_channel_sender(first_view)
    sender2 = build_channel_sender(second_view)
    sender = CombineMMRnnSenderReinforce(sender1, sender2)

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline_types = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }
    baseline = baseline_types[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    # (name, loader, loss) triples evaluated by the hold-out Evaluator.
    loaders = [
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        ),
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        ),
    ]

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    validation_acc = early_stopper.validation_stats[-1][1]["acc"]
    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            """Seed Python, NumPy and torch (plus CUDA when available)."""
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against `sender`; return per-epoch 'acc' values.

            Only the receiver's parameters are handed to the optimizer.
            """
            probe_receiver = receiver_generator()
            probe_game = core.SenderReceiverRnnReinforce(
                sender,
                probe_receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            probe_optimizer = torch.optim.Adam(
                probe_receiver.parameters(), lr=opts.lr
            )
            probe_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            probe_trainer = core.Trainer(
                game=probe_game,
                optimizer=probe_optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[probe_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            probe_trainer.train(n_epochs=opts.n_epochs // 2)

            return [entry[1]["acc"] for entry in probe_stopper.validation_stats]

        frozen_sender = Freezer(copy.deepcopy(sender))

        def make_gru_receiver_generator(n_hidden):
            """Return a zero-argument factory for a GRU receiver of that size."""

            def generate():
                return core.MMRnnReceiverDeterministic(
                    MMReceiver(n_hidden=n_hidden, n_outputs=n_dim),
                    opts.vocab_size + 1,
                    opts.receiver_emb,
                    hidden_size=n_hidden,
                    cell="gru",
                )

            return generate

        def nonlinear_receiver_generator():
            return MMNonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        receiver_generators = [
            ("gru", make_gru_receiver_generator(opts.receiver_hidden)),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", make_gru_receiver_generator(25)),
            ("small_gru", make_gru_receiver_generator(50)),
        ]

        for name, receiver_generator in receiver_generators:
            for seed in (17, 18, 19):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Fewer entries than the epoch budget means training stopped
                # early; the missing epochs are counted as accuracy 1.0.
                accs.extend([1.0] * (opts.n_epochs // 2 - len(accs)))
                auc = sum(accs)
                report = {
                    "mode": "reset",
                    "seed": seed,
                    "receiver_name": name,
                    "auc": auc,
                }
                print(json.dumps(report))

    print("---End--")

    core.close()