Example #1
    def test_MultiCriterion(self):
        input = torch.rand(2, 10)
        target = torch.LongTensor((1, 8))
        nll = nn.ClassNLLCriterion()
        nll2 = nn.CrossEntropyCriterion()
        mc = nn.MultiCriterion().add(nll, 0.5).add(nll2)

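        # MultiCriterion's output should equal the weighted sum of its members:
        # 0.5 * ClassNLLCriterion + 1.0 * CrossEntropyCriterion.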
        output = mc.forward(input, target)
        output2 = nll.forward(input, target) / 2 + nll2.forward(input, target)

        self.assertEqual(output, output2)
        gradInput = mc.backward(input, target)
        gradInput2 = nll.backward(input, target).clone().div(2).add(
            nll2.backward(input, target))
        self.assertEqual(gradInput, gradInput2)

        # test type
        mc.float()
        gradInput = gradInput.clone()
        input3 = input.float()
        target3 = target
        output3 = mc.forward(input3, target3)
        gradInput3 = mc.backward(input3, target3)
        self.assertEqual(output, output3)
        self.assertEqual(gradInput.float(), gradInput3)

        # Check that these don't raise errors
        mc.__repr__()
        str(mc)


# Standard-library imports used below; `nn`, `optim`, `s` (the module whose
# per-sample counters are reset during validation) and `updateStats` are
# assumed to be provided by the surrounding file.
import math
import os
import time

import torch


def ClassificationTrainValidate(model, dataset, p):
    t = model.type()
    if 'nEpochs' not in p:
        p['nEpochs'] = 100
    if 'initial_LR' not in p:
        p['initial_LR'] = 1e-1
    if 'LR_decay' not in p:
        p['LR_decay'] = 4e-2
    if 'weightDecay' not in p:
        p['weightDecay'] = 1e-4
    if 'momentum' not in p:
        p['momentum'] = 0.9
    if 'checkPoint' not in p:
        p['checkPoint'] = False
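    # State handed to optim.sgd. learningRateDecay stays 0 because the epoch
    # loop below overwrites learningRate with its own exponential schedule.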
    optimState = {
        'learningRate': p['initial_LR'],
        'learningRateDecay': 0.0,
        'momentum': p['momentum'],
        'nesterov': True,
        'dampening': 0.0,
        'weightDecay': p['weightDecay'],
        'epoch': 1
    }
    if os.path.isfile('epoch.pth'):
        optimState['epoch'] = torch.load('epoch.pth') + 1
        print('Restarting at epoch ' + str(optimState['epoch']) +
              ' from model.pth ..')
        model = torch.load('model.pth')

    print(p)
    criterion = nn.CrossEntropyCriterion()
    criterion.type(model.type())
    params, gradParams = model.flattenParameters()
    print('#parameters', params.nelement())
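    # Main loop: one training epoch followed by a validation pass, starting
    # from the checkpointed epoch if one was restored above.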
    for epoch in range(optimState['epoch'], p['nEpochs'] + 1):
        model.training()
        stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
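        # Exponential learning-rate decay: initial_LR * exp(-(epoch - 1) * LR_decay).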
        optimState['learningRate'] = p['initial_LR'] * \
            math.exp((1 - epoch) * p['LR_decay'])
        start = time.time()
        for batch in dataset['train']():
            batch['input'].type(t)
            batch['target'] = batch['target'].type(t)
            model.forward(batch['input'])
            criterion.forward(model.output, batch['target'])
            updateStats(stats, model.output, batch['target'], criterion.output)
            gradParams.zero_()  # model:zeroGradParameters()
            criterion.backward(model.output, batch['target'])
            model.backward(batch['input'], criterion.gradInput)

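            # The explicit forward/backward passes above already filled
            # criterion.output and the flattened gradParams, so feval just
            # returns those cached values to the optimizer.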
            def feval(x):
                return criterion.output, gradParams
            optim.sgd(feval, params, optimState)
        print(epoch, 'train: top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %
              (100 * (1 - 1.0 * stats['top1'] / stats['n']),
               100 * (1 - 1.0 * stats['top5'] / stats['n']),
               stats['nll'] / stats['n'],
               time.time() - start))

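        # Optional end-of-epoch checkpoint; clearState() drops cached
        # activations so the serialized model stays small.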
        if p['checkPoint']:
            model.modules[0].clearState()
            torch.save(model, 'model.pth')
            torch.save(epoch, 'epoch.pth')

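        # Validation pass: switch to evaluation mode and reset the per-sample
        # multiply-add / hidden-state counters (`s` is assumed to be the
        # library module imported by the surrounding file).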
        model.evaluate()
        s.forward_pass_multiplyAdd_count = 0
        s.forward_pass_hidden_states = 0
        stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
        start = time.time()
        for batch in dataset['val']():
            batch['input'].type(t)
            batch['target'] = batch['target'].type(t)
            model.forward(batch['input'])
            criterion.forward(model.output, batch['target'])
            updateStats(stats, model.output, batch['target'], criterion.output)
        print(epoch, 'test:  top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %
              (100 * (1 - 1.0 * stats['top1'] / stats['n']),
               100 * (1 - 1.0 * stats['top5'] / stats['n']),
               stats['nll'] / stats['n'],
               time.time() - start))
        print('%.3e MultiplyAdds/sample %.3e HiddenStates/sample' %
              (s.forward_pass_multiplyAdd_count / stats['n'],
               s.forward_pass_hidden_states / stats['n']))