Example #1
    def __init__(self, training_log=None) -> None:
        super().__init__()

        self.training_log = (training_log if training_log is not None else
                             core.get_opts().training_log_path)

        self.train_loader, self.test_loader = \
            get_symbolic_dataloader(
                n_attributes=self.n_attributes,
                n_values=self.n_values,
                batch_size=self.batch_size
            )

        self.sender = core.RnnSenderGS(SymbolicSenderMLP(
            input_dim=self.n_attributes * self.n_values,
            hidden_dim=self.hidden_size),
                                       self.vocab_size,
                                       self.emb_size,
                                       self.hidden_size,
                                       max_len=self.max_len,
                                       cell="lstm",
                                       temperature=1.0)

        self.receiver = core.RnnReceiverGS(SymbolicReceiverMLP(
            self.n_attributes * self.n_values, self.hidden_size),
                                           self.vocab_size,
                                           self.emb_size,
                                           self.hidden_size,
                                           cell="lstm")

        self.game = core.SenderReceiverRnnGS(self.sender, self.receiver,
                                             self.loss)

        self.optimiser = core.build_optimizer(self.game.parameters())
        self.callbacks = []
        self.callbacks.append(
            ConsoleFileLogger(as_json=True,
                              print_train_loss=True,
                              logfile_path=self.training_log))
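        # Anneal the sender's Gumbel-Softmax temperature during training
        # (multiplicative decay of 0.9 per update, floored at 0.1).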
        self.callbacks.append(
            core.TemperatureUpdater(agent=self.sender, decay=0.9, minimum=0.1))
        self.callbacks.append(
            TopographicSimilarityLatents('hamming',
                                         'edit',
                                         log_path=core.get_opts().topo_path))
        self.trainer = core.Trainer(game=self.game,
                                    optimizer=self.optimiser,
                                    train_data=self.train_loader,
                                    validation_data=self.test_loader,
                                    callbacks=self.callbacks)
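The class this __init__ belongs to is not shown in the excerpt. A minimal usage sketch, assuming the constructor sits in a class named SymbolicGame (a hypothetical name) and that training runs through the trainer it builds:

# Hypothetical driver; the class name and epoch count are assumptions.
game = SymbolicGame(training_log="logs/symbolic_game.jsonl")
game.trainer.train(n_epochs=50)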
Example #2
File: train.py Project: wedddy0707/EGG
def main(params):
    import copy

    opts = get_params(params)
    device = opts.device

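    # Enumerate every attribute-value combination, optionally thinning it to
    # the requested density before carving out the holdout splits.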
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)

    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

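    # PlusOneWrapper shifts each sampled symbol up by one so that 0 stays
    # reserved as EOS; this is why the receiver vocabulary is vocab_size + 1.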
    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]["acc"] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
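                # Pad with perfect accuracy after early stopping so the AUC
                # below compares learning speed over a fixed epoch budget.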
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()
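This excerpt ends at main; the usual entry point for such a script (not shown here, so treat it as an assumption) would be:

if __name__ == "__main__":
    import sys

    main(sys.argv[1:])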
Example #3
        target = move_to(target, recon_game.trainer.device)
        label = move_to(label, recon_game.trainer.device)

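        # The reference sender's message is detached below, so the
        # reconstruction loss only updates the receiver, never the sender.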
        msg, _ = ref_game.sender(target)
        rec_out = recon_game.receiver(msg.detach())
        loss, _ = recon_game.loss(target, msg, msg, rec_out, label)
        acc_list.append(loss.mean().item())
        loss.sum().backward()
        optimizer.step()
    print('train loss:', np.mean(acc_list))
    train_loss.append(np.mean(acc_list))

    acc_list = []
    for batch_idx, (target, label) in enumerate(recon_game.test_loader):
        recon_game.game.eval()
        target = move_to(target, recon_game.trainer.device)
        label = move_to(label, recon_game.trainer.device)
        msg, _ = ref_game.sender(target)
        rec_out = recon_game.receiver(msg)
        loss, _ = recon_game.loss(target, msg, msg, rec_out, label)
        acc_list.append(loss.mean().item())
    print('test loss:', np.mean(acc_list))
    test_loss.append(np.mean(acc_list))

file_name = core.get_opts().generalisation_path
create_dir_for_file(file_name)

with open(file_name, 'w') as f:
    for i in range(len(train_loss)):
        print(str(train_loss[i]) + ',' + str(test_loss[i]), file=f)
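create_dir_for_file is not defined in this excerpt; a minimal sketch consistent with the call site (its actual implementation may differ):

import os

def create_dir_for_file(file_path):
    # Ensure the parent directory of file_path exists before writing to it.
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)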
Example #4
def main(params):
    import copy
    opts = get_params(params)
    device = opts.device

    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(full_data, opts.density_data,
                                        opts.n_attributes, opts.n_values)
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(
        train, 1)

    generalization_holdout, uniform_holdout, full_data = ScaledDataset(
        generalization_holdout), ScaledDataset(uniform_holdout), ScaledDataset(
            full_data)
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.MMRnnReceiverDeterministic(receiver,
                                                   opts.vocab_size + 1,
                                                   opts.receiver_emb,
                                                   opts.receiver_hidden,
                                                   cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
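        # SplitWrapper presumably exposes one of MMSender's two output heads;
        # each head drives its own RNN sender, and the two resulting messages
        # are combined further below.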
        s1 = SplitWrapper(sender, 0)
        s2 = SplitWrapper(sender, 1)
        sender1 = core.RnnSenderReinforce(agent=s1,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        sender1 = PlusNWrapper(sender1, 1)
        sender2 = core.RnnSenderReinforce(agent=s2,
                                          vocab_size=opts.vocab_size,
                                          embed_dim=opts.sender_emb,
                                          hidden_size=opts.sender_hidden,
                                          max_len=opts.max_len,
                                          force_eos=False,
                                          cell=opts.sender_cell)
        # sender2 = PlusNWrapper(sender2, opts.vocab_size + 1)
        sender2 = PlusNWrapper(sender2, 1)
        sender = CombineMMRnnSenderReinforce(sender1, sender2)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(validation.examples,
                                  opts.device,
                                  opts.n_attributes,
                                  opts.n_values,
                                  opts.vocab_size + 1,
                                  freq=opts.stats_freq)

    loaders = []
    loaders.append(("generalization hold out", generalization_holdout_loader,
                    DiffLoss(opts.n_attributes,
                             opts.n_values,
                             generalization=True)))
    loaders.append(("uniform holdout", uniform_holdout_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    validation_acc = early_stopper.validation_stats[-1][1]['acc']
    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell='gru')

        def small_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell='gru')

        def tiny_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=25, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=25,
                cell='gru')

        def nonlinear_receiver_generator():
            return MMNonLinearReceiver(n_outputs=n_dim,
                                       vocab_size=opts.vocab_size + 1,
                                       max_length=opts.max_len,
                                       n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('---End--')

    core.close()
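PlusNWrapper and CombineMMRnnSenderReinforce are not defined in this excerpt. A sketch of the former, assuming it generalises PlusOneWrapper from Example #2 by shifting sampled symbols by n:

import torch

class PlusNWrapper(torch.nn.Module):
    # Sketch: shift every sampled symbol by n so low symbol ids stay reserved.
    def __init__(self, wrapped, n):
        super().__init__()
        self.wrapped = wrapped
        self.n = n

    def forward(self, *args, **kwargs):
        # RnnSenderReinforce returns (sequence, log_probs, entropy).
        sequence, log_probs, entropy = self.wrapped(*args, **kwargs)
        return sequence + self.n, log_probs, entropy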
Example #5
def main(params):
    import copy
    opts = get_params(params)
    device = opts.device

    train_loader, validation_loader = get_dsprites_dataloader(
        path_to_data='egg/zoo/data_loaders/data/dsprites.npz',
        batch_size=opts.batch_size,
        subsample=opts.subsample,
        image=False)

    n_dim = opts.n_attributes

    if opts.receiver_cell in ['lstm', 'rnn', 'gru']:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(receiver,
                                                 opts.vocab_size + 1,
                                                 opts.receiver_emb,
                                                 opts.receiver_hidden,
                                                 cell=opts.receiver_cell)
    else:
        raise ValueError(f'Unknown receiver cell, {opts.receiver_cell}')

    if opts.sender_cell in ['lstm', 'rnn', 'gru']:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(agent=sender,
                                         vocab_size=opts.vocab_size,
                                         embed_dim=opts.sender_emb,
                                         hidden_size=opts.sender_hidden,
                                         max_len=opts.max_len,
                                         force_eos=False,
                                         cell=opts.sender_cell)
    else:
        raise ValueError(f'Unknown sender cell, {opts.sender_cell}')

    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        'no': core.baselines.NoBaseline,
        'mean': core.baselines.MeanBaseline,
        'builtin': core.baselines.BuiltInBaseline
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline)
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

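    # Materialise the validation set and keep its first batch of latent
    # values as probe inputs for the metrics callback.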
    latent_values, _ = zip(*[batch for batch in validation_loader])
    latent_values = latent_values[0]

    metrics_evaluator = Metrics(latent_values,
                                opts.device,
                                opts.n_attributes,
                                opts.n_values,
                                opts.vocab_size + 1,
                                freq=opts.stats_freq)

    loaders = []
    #loaders.append(("generalization hold out", generalization_holdout_loader, DiffLoss(
    #opts.n_attributes, opts.n_values, generalization=True)))
    loaders.append(("uniform holdout", validation_loader,
                    DiffLoss(opts.n_attributes, opts.n_values)))

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                         validation=True)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=validation_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=False),
                               early_stopper, metrics_evaluator,
                               holdout_evaluator
                           ])
    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux['acc'].mean()

    uniformtest_acc = holdout_evaluator.results['uniform holdout']['acc']

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random
            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(sender,
                                                   receiver,
                                                   loss,
                                                   sender_entropy_coeff=0.0,
                                                   receiver_entropy_coeff=0.0)
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr,
                                                 validation=True)

            trainer = core.Trainer(game=game,
                                   optimizer=optimizer,
                                   train_data=train_loader,
                                   validation_data=validation_loader,
                                   callbacks=[
                                       early_stopper,
                                       Evaluator(loaders, opts.device, freq=0)
                                   ])
            trainer.train(n_epochs=opts.n_epochs // 2)

            accs = [x[1]['acc'] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell='gru')

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell='gru')

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell='gru')

        def nonlinear_receiver_generator():
            return NonLinearReceiver(n_outputs=n_dim,
                                     vocab_size=opts.vocab_size + 1,
                                     max_length=opts.max_len,
                                     n_hidden=opts.receiver_hidden)

        for name, receiver_generator in [
            ('gru', gru_receiver_generator),
            ('nonlinear', nonlinear_receiver_generator),
            ('tiny_gru', tiny_gru_receiver_generator),
            ('small_gru', small_gru_receiver_generator),
        ]:

            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps({
                        "mode": "reset",
                        "seed": seed,
                        "receiver_name": name,
                        "auc": auc
                    }))

    print('---End--')

    core.close()
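Freezer, used in Examples #2, #4 and #5 to hold the trained sender fixed, is likewise not defined in this excerpt. A sketch consistent with that use (keep the wrapped module in eval mode and block gradient flow):

import torch

class Freezer(torch.nn.Module):
    # Sketch: wrap a trained module so it neither trains nor gets gradients.
    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped
        self.wrapped.eval()

    def train(self, mode=True):
        # Ignore train()/eval() toggles coming from the surrounding game.
        return self

    def forward(self, *args, **kwargs):
        with torch.no_grad():
            return self.wrapped(*args, **kwargs)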