def __init__(self, opts):
    """Create the wrapped sender stored in ``self.sender``.

    A ``Sender`` agent taking ``opts.n_attributes * opts.n_values`` inputs
    is wrapped in ``RnnSenderDeterministic`` with a GRU cell.

    Args:
        opts: options object; this constructor reads ``n_attributes``,
            ``n_values``, ``hidden``, ``vocab_size``, ``sender_emb`` and
            ``max_len`` from it.
    """
    super(OrigSenderDeterministic, self).__init__()

    # Input width is one slot per (attribute, value) pair.
    input_size = opts.n_attributes * opts.n_values
    core_agent = Sender(n_inputs=input_size, n_hidden=opts.hidden)

    rnn_settings = {
        "vocab_size": opts.vocab_size,
        "embed_dim": opts.sender_emb,
        "hidden_size": opts.hidden,
        "max_len": opts.max_len,
        "cell": "gru",
    }
    self.sender = RnnSenderDeterministic(agent=core_agent, **rnn_settings)
def __init__(self, opts):
    """Create the wrapped sender stored in ``self.sender``.

    A ``Sender`` agent taking ``opts.n_attributes * opts.n_values`` inputs
    is wrapped in ``core.RnnSenderReinforce`` with a GRU cell, and the
    result is wrapped once more in ``PlusOneWrapper``.

    Args:
        opts: options object; this constructor reads ``n_attributes``,
            ``n_values``, ``hidden``, ``vocab_size``, ``sender_emb`` and
            ``max_len`` from it.
    """
    super(OrigSender, self).__init__()

    # Input width is one slot per (attribute, value) pair.
    input_size = opts.n_attributes * opts.n_values
    core_agent = Sender(n_inputs=input_size, n_hidden=opts.hidden)

    rnn_settings = {
        "vocab_size": opts.vocab_size,
        "embed_dim": opts.sender_emb,
        "hidden_size": opts.hidden,
        "max_len": opts.max_len,
        "cell": "gru",
    }
    reinforce_sender = core.RnnSenderReinforce(agent=core_agent, **rnn_settings)
    self.sender = PlusOneWrapper(reinforce_sender)
def main(params):
    """Train a sender/receiver game, then probe the frozen sender.

    Steps, in order:

    1. Parse ``params`` with ``get_params`` and enumerate the
       attribute-value data (optionally sub-sampled when
       ``opts.density_data > 0``).
    2. Split the data into train / generalization hold-out / uniform
       hold-out, one-hot encode each split and wrap them in
       ``ScaledDataset`` / ``DataLoader``. The validation set is the
       training split with a scale factor of 1.
    3. Build the receiver and sender, train the game with
       ``core.Trainer`` and read the last validation accuracy.
    4. If that accuracy exceeds 0.99, freeze a copy of the sender and
       retrain several fresh receivers on it for three seeds each,
       printing one JSON line per (receiver, seed) with the summed
       per-epoch validation accuracy under the key ``"auc"``.

    Args:
        params: command-line style parameters forwarded to ``get_params``.

    Raises:
        ValueError: if ``opts.receiver_cell`` or ``opts.sender_cell`` is
            not one of ``"lstm"``, ``"rnn"``, ``"gru"``.
    """
    import copy

    opts = get_params(params)
    supported_cells = ["lstm", "rnn", "gru"]

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # Validation reuses the training examples, but with scale factor 1.
    train, validation = (
        ScaledDataset(train, opts.data_scaler),
        ScaledDataset(train, 1),
    )

    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )
    # NOTE(review): full_data_loader is built but not used below.
    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]

    train_loader = DataLoader(train, batch_size=opts.batch_size)
    # The whole validation set is evaluated as a single batch.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    # ------------------------------------------------------------------
    # Agents
    # ------------------------------------------------------------------
    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell not in supported_cells:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")
    receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )

    if opts.sender_cell not in supported_cells:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")
    sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(
        agent=sender,
        vocab_size=opts.vocab_size,
        embed_dim=opts.sender_emb,
        hidden_size=opts.sender_hidden,
        max_len=opts.max_len,
        cell=opts.sender_cell,
    )
    # The receiver and the metrics below use ``opts.vocab_size + 1``
    # symbols, matching the wrapped sender.
    sender = PlusOneWrapper(sender)

    # ------------------------------------------------------------------
    # Game, optimizer and callbacks
    # ------------------------------------------------------------------
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = [
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        ),
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        ),
    ]

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    # Each validation_stats entry is indexed as a pair; element 1 carries
    # the per-sample accuracies under ``.aux["acc"]``.
    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()

    # NOTE(review): this value is not used below; the lookup is kept as-is.
    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # ------------------------------------------------------------------
    # Train new agents against the frozen sender
    # ------------------------------------------------------------------
    if validation_acc > 0.99:

        def _set_seed(seed):
            """Seed Python, NumPy and torch (plus CUDA when available)."""
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            """Train a fresh receiver against ``sender``.

            Only the receiver's parameters are handed to the optimizer.
            Training runs for at most ``opts.n_epochs // 2`` epochs.

            Returns:
                A list of floats: the mean validation accuracy recorded by
                the early stopper after each validation pass.
            """
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )

            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)

            # Read accuracy the same way the main training run does above
            # (``.aux["acc"].mean()``) instead of subscripting the stats
            # object directly, and convert to plain floats so the summed
            # value can be JSON-serialised together with the 1.0 padding.
            return [
                float(stats.aux["acc"].mean())
                for _, stats in early_stopper.validation_stats
            ]

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            """GRU receiver with the configured hidden size."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            """GRU receiver with hidden size 100."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            """GRU receiver with hidden size 50."""
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            """``NonLinearReceiver`` with the configured hidden size."""
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        probe_epochs = opts.n_epochs // 2
        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Epochs skipped because training stopped early are
                # counted as accuracy 1.0.
                accs += [1.0] * (probe_epochs - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")

    core.close()