Example #1
def test_create_supervised_tbptt_trainer_callcounts(mock_detach_hidden):
    # Mock the model, optimizer and loss function
    model = mock.MagicMock()
    # The TBPTT update step unpacks the model output into (prediction,
    # hidden state), so the mock must return a 2-tuple
    model.return_value = (1, 1)
    optimizer = mock.MagicMock()
    loss = mock.MagicMock()

    trainer = create_supervised_tbptt_trainer(model, optimizer, loss, tbtt_step=2)

    # Add two mock handlers to the trainer to check that the TBPTT events
    # are fired correctly
    handle_started = mock.MagicMock()
    trainer.add_event_handler(Tbptt_Events.TIME_ITERATION_STARTED, handle_started)
    handle_completed = mock.MagicMock()
    trainer.add_event_handler(Tbptt_Events.TIME_ITERATION_COMPLETED, handle_completed)

    # Fake data
    X = torch.ones(6, 2, 1)
    y = torch.ones(6, 2, 1)
    data = [(X, y)]

    # Running trainer
    trainer.run(data)

    # Verifications: a 6-step sequence with tbtt_step=2 runs 3 TBPTT
    # iterations, detaching the hidden state between them (twice)
    assert handle_started.call_count == 3
    assert handle_completed.call_count == 3
    assert mock_detach_hidden.call_count == 2
    assert model.call_count == 3
    assert loss.call_count == 3
    assert optimizer.zero_grad.call_count == 3
    assert optimizer.step.call_count == 3
    # The first model call receives only x; the later calls also receive
    # the carried-over hidden state
    n_args_tuple = tuple(len(args) for args, kwargs in model.call_args_list)
    assert n_args_tuple == (1, 2, 2)
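
The mock_detach_hidden argument is not defined in this snippet; presumably it is injected by a fixture that patches the helper detaching the hidden state between TBPTT windows. A minimal sketch of such a fixture, assuming the helper lives at ignite.contrib.engines.tbptt._detach_hidden:

import mock
import pytest

@pytest.fixture
def mock_detach_hidden():
    # Assumed patch target: the helper that detaches the hidden state
    # between TBPTT windows. Pass the hidden state through unchanged so
    # the trainer still runs.
    with mock.patch("ignite.contrib.engines.tbptt._detach_hidden") as m:
        m.side_effect = lambda hidden: hidden
        yield m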
Example #2
def _test_create_supervised_tbptt_trainer(device):
    # Define a dummy recurrent model with zero weights
    model = nn.RNN(1, 1, bias=False)
    for p in model.parameters():
        p.data.zero_()

    # Register a mock forward hook to monitor the forward calls
    forward_mock = mock.MagicMock()
    forward_mock.return_value = None
    model.register_forward_hook(forward_mock)

    # Define the optimizer and trainer
    optimizer = optim.SGD(model.parameters(), 1)
    trainer = create_supervised_tbptt_trainer(
        model,
        optimizer,
        F.mse_loss,
        tbtt_step=2,
        device=device
    )

    # Fake data
    X = torch.ones(6, 2, 1)
    y = torch.ones(6, 2, 1)
    data = [(X, y)]

    # Running trainer
    trainer.run(data)

    # If TBPTT were not used (a single gradient update over the whole
    # sequence), the hidden-to-hidden weight would stay zero
    assert not model.weight_hh_l0.item() == pytest.approx(0)

    # Check the forward calls
    assert forward_mock.call_count == 3
    for i in range(3):
        inputs = forward_mock.call_args_list[i][0][1]
        if i == 0:
            # The first window has no hidden state yet
            assert len(inputs) == 1
        else:
            # Later windows receive the detached hidden state, which is
            # therefore a graph leaf
            assert len(inputs) == 2
            x, h = inputs
            assert h.is_leaf
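
The forward-call assertions exercise the chunking of the sequence along the time dimension. A minimal sketch of that splitting logic, assuming (seq_len, batch, feature) tensors:

import torch

def tbptt_windows(x, y, tbtt_step):
    # Split (seq_len, batch, feature) tensors into truncation windows
    # along the time dimension.
    return list(zip(torch.split(x, tbtt_step, dim=0),
                    torch.split(y, tbtt_step, dim=0)))

# A 6-step sequence with tbtt_step=2 yields 3 windows, matching the 3
# forward calls checked above. The first window starts without a hidden
# state; each later window receives the previous hidden state detached
# from the graph, which is why h.is_leaf holds.
windows = tbptt_windows(torch.ones(6, 2, 1), torch.ones(6, 2, 1), 2)
assert len(windows) == 3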
Example #3
def _test_create_supervised_tbptt_trainer(device):
    # Define a dummy recurrent model with zero weights
    model = nn.RNN(1, 1, bias=False)
    for p in model.parameters():
        p.data.zero_()

    # Define the optimizer and trainer
    optimizer = optim.SGD(model.parameters(), 1)
    trainer = create_supervised_tbptt_trainer(model,
                                              optimizer,
                                              F.mse_loss,
                                              tbtt_step=2,
                                              device=device)

    # Add two mock handlers to the trainer to check that the TBPTT events
    # are fired correctly
    handle_started = MagicMock()
    trainer.add_event_handler(Tbptt_Events.TIME_ITERATION_STARTED,
                              handle_started)
    handle_completed = MagicMock()
    trainer.add_event_handler(Tbptt_Events.TIME_ITERATION_COMPLETED,
                              handle_completed)

    # Fake data
    X = torch.ones(6, 2, 1)
    y = torch.ones(6, 2, 1)
    data = [(X, y)]

    # Running trainer
    trainer.run(data)

    # Verifications
    assert handle_started.call_count == 3
    assert handle_completed.call_count == 3

    # If TBPTT were not used (a single gradient update over the whole
    # sequence), the hidden-to-hidden weight would stay zero
    assert not model.weight_hh_l0.item() == pytest.approx(0)
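
Handlers for the custom Tbptt_Events can only be attached because the trainer factory registers them on the engine. A minimal sketch of that registration, following Ignite's custom-events pattern (the event names are taken from the snippet; the factory internals are assumed):

from ignite.engine import Engine, EventEnum

class Tbptt_Events(EventEnum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

def _make_tbptt_engine(update_fn):
    # add_event_handler only accepts events the engine knows about, so
    # the custom TBPTT events must be registered first.
    engine = Engine(update_fn)
    engine.register_events(*Tbptt_Events)
    return engine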
Example #4
    num_val_batches = len(valdl)
    num_val_examples = num_val_batches * BATCH_SIZE

    print("Num train examples: {} ({} batches)".format(num_train_examples,
                                                       num_train_batches))
    print("Num validation examples: {} ({} batches)".format(
        num_val_examples, num_val_batches))

    print("\nStarting training for {} epochs...\n".format(NUM_EPOCHS))

    # Debug output: show the training dataloader type
    print(type(traindl))

    # Create the Ignite TBPTT trainer
    trainer = create_supervised_tbptt_trainer(model,
                                              optimizer,
                                              loss_function,
                                              tbtt_step=TBTT_STEP)

    evaluator = create_supervised_tbptt_evaluator(
        model,
        metrics={
            'accuracy': Accuracy(),
            'nll': Loss(loss_function),
            'precision': Precision(output_transform=thresholded_output_transform),
            'recall': Recall(output_transform=thresholded_output_transform)
        })
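
    # Note: thresholded_output_transform is not shown here. A hypothetical
    # implementation in the usual Ignite style (assumed, not from the
    # source) binarizes the predictions before Precision/Recall compare
    # them with the targets:
    #
    #     def thresholded_output_transform(output):
    #         y_pred, y = output
    #         return torch.round(y_pred), y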

    # evaluator = create_supervised_evaluator(model, metrics=['accuracy'])

    @trainer.on(Events.ITERATION_COMPLETED)