# Example #1
# 0
def train(dataloader):
    """Train a freshly initialised CNN for ``iters`` steps and persist the results.

    Uses names that are not defined in this function and are expected at
    module level: ``CNN``, ``weights_init``, ``device``, ``learning_rate``,
    ``iters``, ``PATH``, ``save_dir``, ``save`` (presumably ``torch.save`` --
    confirm) and ``alive_bar``.

    Args:
        dataloader: An iterator; ``next()`` is called on it once per step and
            must yield an ``(images, labels)`` pair.

    Side effects:
        * Creates ``PATH/loss`` if it does not exist.
        * Appends the current loss to ``PATH/loss/loss.txt`` every
          ``max(1, iters // 100)`` steps.
        * Saves ``cnn.state_dict()`` to ``PATH/save_dir``.
        * Writes the per-step losses (key ``'loss'``) and the 0-based step
          indices (key ``'iter'``) to the shelve database ``PATH/loss/lossdata``.
    """
    print('Start training...')

    cnn = CNN().to(device)
    cnn.train()
    cnn.apply(weights_init)

    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.AdamW(cnn.parameters(), lr=learning_rate)

    loss_dir = os.path.join(PATH, 'loss')
    # Ensure the output directory exists so open()/shelve.open() don't fail.
    os.makedirs(loss_dir, exist_ok=True)
    log_path = os.path.join(loss_dir, 'loss.txt')  # file that stores the loss log

    # Integer logging interval. The previous float form
    # ``(epoch + 1) % (iters / 100) == 0`` rarely hit exactly 0 when ``iters``
    # was not a multiple of 100, so logging silently stopped or became erratic.
    log_every = max(1, iters // 100)
    loss_list = []

    # The context manager closes the log even if a training step raises.
    with open(log_path, 'a', encoding='utf-8') as losslog, alive_bar(iters) as bar:
        for step in range(iters):
            img, labels = next(dataloader)
            img = img.to(device)
            labels = labels.to(device)

            predicted_labels = cnn(img)

            # Both tensors are cast to float64 before computing the loss.
            loss = criterion(predicted_labels.double(), labels.double())
            # Read the scalar once and reuse it for the history and the log.
            loss_value = loss.item()
            loss_list.append(loss_value)

            if (step + 1) % log_every == 0:
                losslog.write('loss at iter ' + str(step + 1) + ':   ' +
                              str(loss_value) + '\n')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Advance the progress bar only after the step has completed.
            bar()

    save(cnn.state_dict(), os.path.join(PATH, save_dir))
    with shelve.open(os.path.join(loss_dir, 'lossdata')) as d:
        d['loss'] = loss_list
        d['iter'] = list(range(iters))

    print("model saved.")