def test_early_stopping():
    """One epoch of training should leave ``trainer.should_stop`` set.

    The same mock dataset is used for training and validation, and an
    ``EarlyStopperAccuracy`` with ``threshold=0.9`` is the only callback.
    """
    mock_game = MockGame()
    dataset = Dataset()
    stopper = core.EarlyStopperAccuracy(threshold=0.9)

    adam = torch.optim.Adam(mock_game.parameters())
    trainer = core.Trainer(
        game=mock_game,
        optimizer=adam,
        train_data=dataset,
        validation_data=dataset,
        callbacks=[stopper],
    )

    trainer.train(1)

    assert trainer.should_stop
def main(params):
    """Train an RNN receiver against a fixed-choice sender on attribute/value data.

    Parses ``params`` with ``get_params``, builds train/test
    ``AttributeValueData`` loaders, and selects the sender by
    ``opts.language`` (``IdentitySender`` or ``RotatedSender``). It wraps a
    ``Receiver`` in ``core.RnnReceiverDeterministic`` and trains the pair as a
    ``core.SenderReceiverRnnReinforce`` game. Only the receiver's parameters
    are handed to the optimizer.

    Args:
        params: arguments forwarded to ``get_params`` (presumably a CLI-style
            argument list -- confirm against caller).

    Raises:
        ValueError: if ``opts.language`` is neither ``'identity'`` nor
            ``'rotated'``.
    """
    opts = get_params(params)
    print(opts)

    n_a, n_v = opts.n_a, opts.n_v
    # The receiver's vocabulary is tied to the number of values per attribute.
    opts.vocab_size = n_v

    train_data = AttributeValueData(n_attributes=n_a, n_values=n_v, mul=1, mode='train')
    train_loader = DataLoader(train_data, batch_size=opts.batch_size, shuffle=True)
    test_data = AttributeValueData(n_attributes=n_a, n_values=n_v, mul=1, mode='test')
    test_loader = DataLoader(test_data, batch_size=opts.batch_size, shuffle=False)
    print(f'# Size of train {len(train_data)} test {len(test_data)}')

    if opts.language == 'identity':
        sender = IdentitySender(n_a, n_v)
    elif opts.language == 'rotated':
        sender = RotatedSender(n_a, n_v)
    else:
        # An explicit raise survives `python -O`, unlike the former `assert False`,
        # and tells the user which value was rejected.
        raise ValueError(
            f"Unknown language {opts.language!r}; expected 'identity' or 'rotated'"
        )

    receiver = Receiver(
        n_hidden=opts.receiver_hidden,
        n_dim=n_a * n_v,
        inner_layers=opts.receiver_layers,
    )
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # +1 leaves room for the EOS symbol (eos = 0)
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers,
    )

    diff_loss = DiffLoss(n_a, n_v, loss_type=opts.loss_type)
    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        diff_loss,
        receiver_entropy_coeff=0.05,
        sender_entropy_coeff=0.0,
    )

    # Only the receiver is optimized; the sender's parameters (if any) are
    # not passed to the optimizer.
    optimizer = core.build_optimizer(receiver.parameters())

    # validation=False: presumably the stopper watches training accuracy and
    # stops at 1.0 -- confirm against core.EarlyStopperAccuracy.
    early_stopper = core.EarlyStopperAccuracy(1.0, validation=False)
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True, print_train_loss=True),
            early_stopper,
        ],
        grad_norm=1.0,  # presumably gradient-norm clipping -- confirm in core.Trainer
    )

    trainer.train(n_epochs=opts.n_epochs)
    core.close()