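# NOTE: the listings in this section are shown without their import headers.
# Below is a plausible header for the first listing, a sketch that follows the
# egg.zoo.compo_vs_generalization layout of the EGG repository. The
# multi-message classes used later (MMSender, MMReceiver, MMMetrics,
# MMNonLinearReceiver, SplitWrapper, PlusNWrapper, CombineMMRnnSenderReinforce)
# are project-local; their module path is not shown in the original and is
# left out here rather than guessed. DiffLoss and get_params are assumed to be
# defined in the same file, as in the repository's train.py.
import json

import torch
from torch.utils.data import DataLoader

import egg.core as core
from egg.core import EarlyStopperAccuracy
from egg.zoo.compo_vs_generalization.archs import Freezer
from egg.zoo.compo_vs_generalization.data import (
    ScaledDataset,
    enumerate_attribute_value,
    one_hotify,
    select_subset_V2,
    split_holdout,
    split_train_test,
)
from egg.zoo.compo_vs_generalization.intervention import Evaluator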
def main(params):
    import copy

    opts = get_params(params)
    device = opts.device

    # Enumerate the full attribute-value input space, optionally subsampled.
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    # Carve out the generalization holdout, then a uniform 10% test split.
    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # The validation set is the unscaled training data.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)
    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )

    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]
    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.MMRnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        # A single sender body feeding two message channels.
        sender = MMSender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        s1 = SplitWrapper(sender, 0)
        s2 = SplitWrapper(sender, 1)

        sender1 = core.RnnSenderReinforce(
            agent=s1,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            force_eos=False,
            cell=opts.sender_cell,
        )
        # Shift message symbols by one, keeping 0 reserved as EOS.
        sender1 = PlusNWrapper(sender1, 1)

        sender2 = core.RnnSenderReinforce(
            agent=s2,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            force_eos=False,
            cell=opts.sender_cell,
        )
        # sender2 = PlusNWrapper(sender2, opts.vocab_size + 1)
        sender2 = PlusNWrapper(sender2, 1)

        sender = CombineMMRnnSenderReinforce(sender1, sender2)
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = MMMetrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    validation_acc = early_stopper.validation_stats[-1][1]["acc"]
    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )
            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)
            accs = [x[1]["acc"] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            return core.MMRnnReceiverDeterministic(
                MMReceiver(n_hidden=25, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=25,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            return MMNonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad with 1.0 for the epochs skipped by early stopping, so the
                # area under the learning curve is comparable across runs.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")
    core.close()
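# A plausible import header for the second listing, mirroring
# egg/zoo/compo_vs_generalization/train.py in the EGG repository (a sketch,
# not part of the original listing; get_params and DiffLoss are assumed to be
# defined in the same file):
import json

import torch
from torch.utils.data import DataLoader

import egg.core as core
from egg.core import EarlyStopperAccuracy
from egg.zoo.compo_vs_generalization.archs import (
    Freezer,
    NonLinearReceiver,
    PlusOneWrapper,
    Receiver,
    Sender,
)
from egg.zoo.compo_vs_generalization.data import (
    ScaledDataset,
    enumerate_attribute_value,
    one_hotify,
    select_subset_V2,
    split_holdout,
    split_train_test,
)
from egg.zoo.compo_vs_generalization.intervention import Evaluator, Metrics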
def main(params):
    import copy

    opts = get_params(params)
    device = opts.device

    # Enumerate the full attribute-value input space, optionally subsampled.
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    # Carve out the generalization holdout, then a uniform 10% test split.
    train, generalization_holdout = split_holdout(full_data)
    train, uniform_holdout = split_train_test(train, 0.1)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    # The validation set is the unscaled training data.
    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)
    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )

    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]
    train_loader = DataLoader(train, batch_size=opts.batch_size)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    if opts.receiver_cell in ["lstm", "rnn", "gru"]:
        receiver = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size + 1,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
    else:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")

    if opts.sender_cell in ["lstm", "rnn", "gru"]:
        sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )
    else:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")

    # Shift message symbols by one, keeping 0 reserved as EOS.
    sender = PlusOneWrapper(sender)
    loss = DiffLoss(opts.n_attributes, opts.n_values)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=0)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            metrics_evaluator,
            holdout_evaluator,
        ],
    )

    trainer.train(n_epochs=opts.n_epochs)

    last_epoch_interaction = early_stopper.validation_stats[-1][1]
    validation_acc = last_epoch_interaction.aux["acc"].mean()
    uniformtest_acc = holdout_evaluator.results["uniform holdout"]["acc"]

    # Train new agents
    if validation_acc > 0.99:

        def _set_seed(seed):
            import random

            import numpy as np

            random.seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        core.get_opts().preemptable = False
        core.get_opts().checkpoint_path = None

        # freeze Sender and probe how fast a simple Receiver will learn the thing
        def retrain_receiver(receiver_generator, sender):
            receiver = receiver_generator()
            game = core.SenderReceiverRnnReinforce(
                sender,
                receiver,
                loss,
                sender_entropy_coeff=0.0,
                receiver_entropy_coeff=0.0,
            )
            optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
            early_stopper = EarlyStopperAccuracy(
                opts.early_stopping_thr, validation=True
            )
            trainer = core.Trainer(
                game=game,
                optimizer=optimizer,
                train_data=train_loader,
                validation_data=validation_loader,
                callbacks=[early_stopper, Evaluator(loaders, opts.device, freq=0)],
            )
            trainer.train(n_epochs=opts.n_epochs // 2)
            accs = [x[1]["acc"] for x in early_stopper.validation_stats]
            return accs

        frozen_sender = Freezer(copy.deepcopy(sender))

        def gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=opts.receiver_hidden,
                cell="gru",
            )

        def small_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=100, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=100,
                cell="gru",
            )

        def tiny_gru_receiver_generator():
            return core.RnnReceiverDeterministic(
                Receiver(n_hidden=50, n_outputs=n_dim),
                opts.vocab_size + 1,
                opts.receiver_emb,
                hidden_size=50,
                cell="gru",
            )

        def nonlinear_receiver_generator():
            return NonLinearReceiver(
                n_outputs=n_dim,
                vocab_size=opts.vocab_size + 1,
                max_length=opts.max_len,
                n_hidden=opts.receiver_hidden,
            )

        for name, receiver_generator in [
            ("gru", gru_receiver_generator),
            ("nonlinear", nonlinear_receiver_generator),
            ("tiny_gru", tiny_gru_receiver_generator),
            ("small_gru", small_gru_receiver_generator),
        ]:
            for seed in range(17, 17 + 3):
                _set_seed(seed)
                accs = retrain_receiver(receiver_generator, frozen_sender)
                # Pad with 1.0 for the epochs skipped by early stopping, so the
                # area under the learning curve is comparable across runs.
                accs += [1.0] * (opts.n_epochs // 2 - len(accs))
                auc = sum(accs)
                print(
                    json.dumps(
                        {
                            "mode": "reset",
                            "seed": seed,
                            "receiver_name": name,
                            "auc": auc,
                        }
                    )
                )

    print("---End--")
    core.close()
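# A minimal entry point for the listing above, following the usual EGG zoo
# convention (an assumption; the original listing does not show it):
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])


# A plausible import header for the OOD listing that follows (a sketch; the
# archs module path is taken from the listing's own getattr call, while the
# other module paths, including get_data, are assumptions that mirror the
# shared compo_vs_generalization helpers):
import torch
from torch.utils.data import DataLoader

import egg.core as core
import egg.zoo.compo_vs_generalization_ood.archs
from egg.core import EarlyStopperAccuracy
from egg.zoo.compo_vs_generalization.data import ScaledDataset, one_hotify
from egg.zoo.compo_vs_generalization.intervention import Evaluator, Metrics
from egg.zoo.compo_vs_generalization.train import DiffLoss
from egg.zoo.compo_vs_generalization_ood.data import get_data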
def main(params):
    opts = get_params(params)
    print(opts)

    full_data, train, uniform_holdout, generalization_holdout = get_data(opts)

    generalization_holdout, train, uniform_holdout, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [generalization_holdout, train, uniform_holdout, full_data]
    ]

    train, validation = ScaledDataset(train, opts.data_scaler), ScaledDataset(train, 1)
    generalization_holdout, uniform_holdout, full_data = (
        ScaledDataset(generalization_holdout),
        ScaledDataset(uniform_holdout),
        ScaledDataset(full_data),
    )

    generalization_holdout_loader, uniform_holdout_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [generalization_holdout, uniform_holdout, full_data]
    ]
    train_loader = DataLoader(train, batch_size=opts.batch_size, shuffle=True)
    validation_loader = DataLoader(validation, batch_size=len(validation))

    loss = DiffLoss(opts.n_attributes, opts.n_values)

    sender = getattr(egg.zoo.compo_vs_generalization_ood.archs, opts.sender)(opts)
    receiver = getattr(egg.zoo.compo_vs_generalization_ood.archs, opts.receiver)(opts)

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=core.baselines.MeanBaseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    metrics_evaluator = Metrics(
        validation.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )
    metrics_evaluator_generalization_holdout = Metrics(
        generalization_holdout.examples,
        opts.device,
        opts.n_attributes,
        opts.n_values,
        opts.vocab_size + 1,
        freq=opts.stats_freq,
    )

    loaders = []
    loaders.append(
        (
            "generalization hold out",
            generalization_holdout_loader,
            # DiffLoss(opts.n_attributes, opts.n_values, generalization=True),
            # we don't want to ignore zeros:
            DiffLoss(opts.n_attributes, opts.n_values, generalization=False),
        )
    )
    loaders.append(
        (
            "uniform holdout",
            uniform_holdout_loader,
            DiffLoss(opts.n_attributes, opts.n_values),
        )
    )

    holdout_evaluator = Evaluator(loaders, opts.device, freq=1)
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            # print validation (i.e. unscaled training data) loss:
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            early_stopper,
            # print compositionality metrics at the end of training
            # (validation, i.e., unscaled training data):
            metrics_evaluator,
            # print compositionality metrics at the end of training (holdout data):
            metrics_evaluator_generalization_holdout,
            # print generalization and uniform holdout accuracies at each epoch:
            holdout_evaluator,
        ],
    )

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
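# A minimal entry point for the OOD listing, again following the usual EGG zoo
# convention (an assumption; not shown in the original):
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])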