Example #1
0
    # One training pass over `trainloader`, followed by a 100-member ensemble
    # evaluation on `testloader`. Relies on `epoch`, `epochs`, `lr_start`,
    # `net`, `optimizer`, `criterion`, `flags` and `logger` from the enclosing
    # scope (presumably the body of the per-epoch loop -- header not shown here).
    t0 = time()  # start timestamp; whatever consumes it is outside this section
    # Apply the learning rate returned by the schedule helper for this epoch.
    utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, epochs, lr_start))
    net.train()
    training_loss = 0
    accs = []
    steps = 0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        steps += 1
        # `async` is a reserved word since Python 3.7; PyTorch >= 0.4 names the
        # flag `non_blocking`. Tensors no longer need a `Variable` wrapper.
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        optimizer.zero_grad()

        # Run `flags.train_ens` forward passes on the same batch and store each
        # pass's log-probabilities (log-softmax over dim 1) along the last axis.
        outputs = torch.zeros(inputs.shape[0], net.num_classes, flags.train_ens).cuda()
        for j in range(flags.train_ens):
            outputs[:, :, j] = F.log_softmax(net(inputs), dim=1)
        # Reduce over the ensemble axis (presumably log of the mean probability,
        # going by the helper's name -- confirm against utils.logmeanexp).
        log_outputs = utils.logmeanexp(outputs, dim=2)

        loss = criterion(log_outputs, labels)
        loss.backward()
        optimizer.step()

        accs.append(metrics.logit2acc(log_outputs.data, labels))
        # The loss is a 0-dim tensor on PyTorch >= 0.4, so `.numpy()[0]` would
        # raise IndexError; `.item()` extracts the Python scalar.
        training_loss += loss.item()

    # Log the mean per-batch loss and the mean of the per-batch accuracies.
    logger.add(epoch, tr_loss=training_loss/steps, tr_acc=np.mean(accs))

    # Ens 100 test
    # NOTE(review): the net is put in train mode (not eval mode) before
    # evaluation; presumably intentional so the 100 ensemble passes stay
    # stochastic -- confirm.
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=100)
    logger.add(epoch, te_nll_ens100=nll, te_acc_ens100=acc)
    # One training pass over `trainloader`, followed by a 100-member ensemble
    # evaluation on `testloader`. Relies on `epoch`, `epochs`, `lr_start`,
    # `net`, `optimizer`, `criterion`, `flags` and `logger` from the enclosing
    # scope (presumably the body of the per-epoch loop -- header not shown here).
    t0 = time()  # start timestamp; whatever consumes it is outside this section
    # Apply the learning rate returned by the schedule helper for this epoch.
    utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, epochs, lr_start))
    net.train()
    training_loss = 0
    accs = []
    steps = 0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        steps += 1
        # `async` is a reserved word since Python 3.7; PyTorch >= 0.4 names the
        # flag `non_blocking`. Tensors no longer need a `Variable` wrapper.
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        optimizer.zero_grad()

        # Run `flags.train_ens` forward passes on the same batch and store each
        # pass's log-probabilities (log-softmax over dim 1) along the last axis.
        outputs = torch.zeros(inputs.shape[0], net.num_classes, flags.train_ens).cuda()
        for j in range(flags.train_ens):
            outputs[:, :, j] = F.log_softmax(net(inputs), dim=1)
        # Reduce over the ensemble axis (presumably log of the mean probability,
        # going by the helper's name -- confirm against utils.logmeanexp).
        log_outputs = utils.logmeanexp(outputs, dim=2)

        loss = criterion(log_outputs, labels)
        loss.backward()
        optimizer.step()

        accs.append(metrics.logit2acc(log_outputs.data, labels))
        # The loss is a 0-dim tensor on PyTorch >= 0.4, so `.numpy()[0]` would
        # raise IndexError; `.item()` extracts the Python scalar.
        training_loss += loss.item()

    # Log the mean per-batch loss and the mean of the per-batch accuracies.
    logger.add(epoch, tr_loss=training_loss/steps, tr_acc=np.mean(accs))

    # Ens 100 test
    # NOTE(review): the net is put in train mode (not eval mode) before
    # evaluation; presumably intentional so the 100 ensemble passes stay
    # stochastic -- confirm.
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=100)
    logger.add(epoch, te_nll_ens100=nll, te_acc_ens100=acc)