示例#1
0
def test_early_stopping():
    """Check that one training epoch leaves the trainer flagged to stop.

    An accuracy-based early stopper with a 0.9 threshold is registered as
    the only callback; after ``trainer.train(1)`` the test expects
    ``trainer.should_stop`` to be truthy.  (Presumably ``MockGame`` reaches
    accuracy >= 0.9 on ``Dataset`` immediately -- confirm in its definition.)
    """
    mock_game = MockGame()
    dataset = Dataset()
    stopper = core.EarlyStopperAccuracy(threshold=0.9)
    adam = torch.optim.Adam(mock_game.parameters())

    trainer = core.Trainer(
        game=mock_game,
        optimizer=adam,
        train_data=dataset,
        validation_data=dataset,  # the same data doubles as the validation split
        callbacks=[stopper],
    )
    trainer.train(1)

    assert trainer.should_stop
示例#2
0
def _build_loader(n_attributes, n_values, mode, batch_size, shuffle):
    """Build one split of ``AttributeValueData`` and a ``DataLoader`` over it.

    Returns:
        ``(dataset, loader)`` -- the dataset is returned too so the caller
        can report the split size.
    """
    dataset = AttributeValueData(n_attributes=n_attributes,
                                 n_values=n_values,
                                 mul=1,
                                 mode=mode)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


def main(params):
    """Train an RNN receiver to decode a hand-coded sender's messages.

    Parses ``params`` with ``get_params`` and builds train/test loaders over
    ``AttributeValueData``. It then picks a sender according to
    ``opts.language`` (``'identity'`` or ``'rotated'``) and trains the receiver
    inside ``core.SenderReceiverRnnReinforce``. Training runs for
    ``opts.n_epochs`` epochs or until the accuracy early stopper fires.

    Args:
        params: argument list forwarded to ``get_params`` (presumably
            command-line style strings -- confirm against the caller).

    Raises:
        ValueError: if ``opts.language`` is neither ``'identity'`` nor
            ``'rotated'``.
    """
    opts = get_params(params)
    print(opts)

    n_a, n_v = opts.n_a, opts.n_v
    # Vocabulary size is tied to the number of values per attribute.
    opts.vocab_size = n_v

    train_data, train_loader = _build_loader(
        n_a, n_v, 'train', opts.batch_size, shuffle=True)
    test_data, test_loader = _build_loader(
        n_a, n_v, 'test', opts.batch_size, shuffle=False)

    print(f'# Size of train {len(train_data)} test {len(test_data)}')

    if opts.language == 'identity':
        sender = IdentitySender(n_a, n_v)
    elif opts.language == 'rotated':
        sender = RotatedSender(n_a, n_v)
    else:
        # An explicit raise (unlike `assert False`) still fires under
        # `python -O`, where asserts are stripped and `sender` would be unbound.
        raise ValueError(f"Unknown language {opts.language!r}; "
                         "expected 'identity' or 'rotated'")

    receiver = Receiver(n_hidden=opts.receiver_hidden,
                        n_dim=n_a * n_v,
                        inner_layers=opts.receiver_layers)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        opts.vocab_size + 1,  # exclude eos = 0
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
        num_layers=opts.cell_layers)

    diff_loss = DiffLoss(n_a, n_v, loss_type=opts.loss_type)

    game = core.SenderReceiverRnnReinforce(sender,
                                           receiver,
                                           diff_loss,
                                           receiver_entropy_coeff=0.05,
                                           sender_entropy_coeff=0.0)

    # Only receiver parameters are optimized. The sender is hand-coded and
    # presumably has no trainable parameters -- confirm.
    optimizer = core.build_optimizer(receiver.parameters())

    # validation=False: presumably monitors training accuracy against the 1.0
    # threshold rather than validation accuracy -- confirm in core.
    early_stopper = core.EarlyStopperAccuracy(1.0, validation=False)

    trainer = core.Trainer(game=game,
                           optimizer=optimizer,
                           train_data=train_loader,
                           validation_data=test_loader,
                           callbacks=[
                               core.ConsoleLogger(as_json=True,
                                                  print_train_loss=True),
                               early_stopper
                           ],
                           grad_norm=1.0)

    trainer.train(n_epochs=opts.n_epochs)
    core.close()