コード例 #1
0
ファイル: main.py プロジェクト: yyy11178/GSULC
def main():
    """Run one step of the two-step BCNN training procedure (GSULC variant).

    Step 1 optimizes only ``bnn.module.fc`` of a BCNN built with
    ``pretrained=True``. Step 2 optimizes all parameters of a BCNN built with
    ``pretrained=False``, initialised from ``model/bnn_step1_vgg16_best_epoch.pth``.
    If ``resume`` is set, the model, optimizer, epoch/best-accuracy counters and
    the ``Correcter`` state are restored from ``checkpoint.pth`` instead.

    Configuration is read from module-level globals defined elsewhere in the
    file: ``step``, ``drop_rate``, ``num_classes``, ``learning_rate``,
    ``weight_decay``, ``num_train_images``, ``queue_size``, ``resume``,
    ``logfile``, ``warm_up``, ``num_epochs``, ``train_loader``, ``test_loader``.

    Side effects:
        - Prints progress and appends per-epoch results to ``logfile``.
        - Saves the best model to ``model/bnn_step{step}_vgg16_best_epoch.pth``.
        - Overwrites ``checkpoint.pth`` after every epoch.

    Raises:
        AssertionError: if ``step`` is neither 1 nor 2. Under ``resume``, also
            raised (via ``assert``, so skipped with ``python -O``) when
            ``checkpoint.pth`` is missing or was written for another step.
    """
    # `step` is a module-level global (it was formerly taken from args.step).
    print('===> About training in a two-step process! ===')
    print('------\n' 'drop rate: [{}]\t' '\n------'.format(drop_rate))

    # Step 1: only train the fc layer.
    if step == 1:
        print('===> Step 1 ...')
        bnn = BCNN(pretrained=True, n_classes=num_classes)
        bnn = nn.DataParallel(bnn).cuda()
        optimizer = optim.Adam(bnn.module.fc.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    # Step 2: train the whole network.
    elif step == 2:
        print('===> Step 2 ...')
        bnn = BCNN(pretrained=False, n_classes=num_classes)
        bnn = nn.DataParallel(bnn).cuda()
        optimizer = optim.Adam(bnn.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    else:
        raise AssertionError('Wrong step argument')

    correcter = self_correcter.Correcter(num_train_images, num_classes,
                                         queue_size)

    loadmodel = 'checkpoint.pth'

    # Check whether we are resuming from a previous checkpoint.
    print(
        '-----------------------------------------------------------------------------'
    )
    if resume:
        assert os.path.isfile(
            loadmodel), 'please make sure checkpoint.pth exists'
        print('---> loading checkpoint.pth <---')
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
        # if the Correcter fields are not plain tensors/containers this may
        # need weights_only=False (only for trusted files) — confirm.
        checkpoint = torch.load(loadmodel)
        assert step == checkpoint[
            'step'], 'step in checkpoint does not match step in argument'
        start_epoch = checkpoint['epoch']
        best_accuracy = checkpoint['best_accuracy']
        best_epoch = checkpoint['best_epoch']
        bnn.load_state_dict(checkpoint['bnn_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        correcter.all_predictions = checkpoint['all_predictions']
        correcter.softmax_record = checkpoint['softmax_record']
        correcter.update_counters = checkpoint['update_counters']
    else:
        if step == 2:
            print('--->        step2 checkpoint loaded         <---')
            # Initialise step 2 from the best step-1 weights.
            bnn.load_state_dict(
                torch.load('model/bnn_step1_vgg16_best_epoch.pth'))
        else:
            print('--->        no checkpoint loaded         <---')

        start_epoch = 0
        best_accuracy = 0.0
        # Stays None until some epoch beats best_accuracy.
        best_epoch = None

    print(
        '-----------------------------------------------------------------------------'
    )

    with open(logfile, "a") as f:
        f.write('------ Step: {} ...\n'.format(step))
        f.write('------\n'
                'drop rate: [{}]\tqueue_size: [{}]\t'
                'warm_up: [{}]\tinit_lr: [{}]\t'
                '\n'.format(drop_rate, queue_size, warm_up, learning_rate))

    # Halve the LR when test accuracy plateaus (stepped only after warm-up).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     factor=0.5,
                                                     patience=4,
                                                     verbose=True,
                                                     threshold=learning_rate *
                                                     1e-3)

    # The best-model file is written into model/; create it up front so a
    # missing directory does not crash after a full training epoch.
    os.makedirs('model', exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()

        bnn.train()

        # During the first `warm_up` epochs no clean/unclean split is computed
        # and the LR scheduler is not stepped.
        warm = epoch < warm_up

        if not warm:
            correcter.separate_clean_and_unclean_keys(drop_rate)
            # Prints the number of clean samples.
            print("干净的样本数:", len(correcter.clean_key))

        train_acc, train_total = train(train_loader,
                                       epoch,
                                       bnn,
                                       optimizer,
                                       warm,
                                       correcter=correcter)

        test_acc = evaluate(test_loader, bnn)
        if not warm:
            scheduler.step(test_acc)

        if test_acc > best_accuracy:
            best_accuracy = test_acc
            best_epoch = epoch + 1
            torch.save(bnn.state_dict(),
                       'model/bnn_step{}_vgg16_best_epoch.pth'.format(step))

        epoch_end_time = time.time()
        print("all_predictions", len(correcter.all_predictions[0]))
        print("update_counters", correcter.update_counters[0])
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'bnn_state_dict': bnn.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_epoch': best_epoch,
                'best_accuracy': best_accuracy,
                'step': step,
                'all_predictions': correcter.all_predictions,
                'softmax_record': correcter.softmax_record,
                'update_counters': correcter.update_counters
            },
            filename=loadmodel)

        epoch_runtime = epoch_end_time - epoch_start_time
        print('------\n'
              'Epoch: [{:03d}/{:03d}]\tTrain Accuracy: [{:6.2f}]\t'
              'Test Accuracy: [{:6.2f}]\t'
              'Epoch Runtime: [{:6.2f}]\t'
              '\n------'.format(epoch + 1, num_epochs, train_acc, test_acc,
                                epoch_runtime))
        with open(logfile, "a") as f:
            # NOTE(review): correcter.clean_key is also read during warm-up
            # epochs here; assumes Correcter initialises it — confirm.
            output = ('Epoch: [{:03d}/{:03d}]\tTrain Accuracy: [{:6.2f}]\t'
                      'Test Accuracy: [{:6.2f}]\t'
                      'Epoch Runtime: [{:7.2f}]\tTrain_total[{:06d}]\t'
                      'clean_key[{:06d}]'.format(epoch + 1, num_epochs,
                                                 train_acc, test_acc,
                                                 epoch_runtime, train_total,
                                                 len(correcter.clean_key)))
            f.write(output + "\n")

    # best_epoch is still None when no epoch improved on best_accuracy (e.g.
    # num_epochs <= start_epoch, or every test accuracy was 0.0); '{:03d}'
    # would raise TypeError on None, so format the epoch separately.
    if best_epoch is None:
        best_epoch_label = 'N/A'
    else:
        best_epoch_label = '{:03d}'.format(best_epoch)

    print('******\n'
          'Best Accuracy 1: [{0:6.2f}], at Epoch [{1}] '
          '\n******'.format(best_accuracy, best_epoch_label))
    with open(logfile, "a") as f:
        output = ('******\n'
                  'Best Accuracy 1: [{0:6.2f}], at Epoch [{1}]; '
                  '\n******'.format(best_accuracy, best_epoch_label))
        f.write(output + "\n")
コード例 #2
0
ファイル: train.py プロジェクト: ljm198134/Softly-Update-Drop
def main():
    """Run one step of the two-step BCNN training procedure (Softly-Update-Drop).

    Step 1 optimizes only ``bnn.module.fc`` of a BCNN built with
    ``pretrained=True``. Step 2 optimizes all parameters of a BCNN built with
    ``pretrained=False``, initialised from ``model/bnn_step1_vgg16_best_epoch.pth``.
    If ``resume`` is set, the model, optimizer, epoch/best-accuracy counters and
    the ``Cross_entropy`` / ``logits_softmax`` lists are restored from
    ``checkpoint.pth``. Both lists are passed into and returned by ``train``
    every epoch.

    Configuration is read from module-level globals defined elsewhere in the
    file: ``step``, ``drop_rate``, ``T_k``, ``start``, ``num_classes``,
    ``learning_rate``, ``weight_decay``, ``resume``, ``logfile``,
    ``num_epochs``, ``train_loader``, ``test_loader``. The learning rate is
    set per epoch by ``adjust_learning_rate(optimizer, epoch)``.

    Side effects:
        - Prints progress and appends per-epoch results to ``logfile``.
        - Saves the best model to ``model/bnn_step{step}_vgg16_best_epoch.pth``.
        - Overwrites ``checkpoint.pth`` after every epoch.

    Raises:
        AssertionError: if ``step`` is neither 1 nor 2. Under ``resume``, also
            raised (via ``assert``, so skipped with ``python -O``) when
            ``checkpoint.pth`` is missing or was written for another step.
    """
    # `step` is a module-level global (it was formerly taken from args.step).
    print('===> About training in a two-step process! ===')
    print('------\n'
          'drop rate: [{}]\tT_k: [{}]\t'
          'start epoch: [{}]\t'
          '\n------'.format(drop_rate, T_k, start))

    # Step 1: only train the fc layer.
    if step == 1:
        print('===> Step 1 ...')
        bnn = BCNN(pretrained=True, n_classes=num_classes)
        bnn = nn.DataParallel(bnn).cuda()
        optimizer = optim.Adam(bnn.module.fc.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    # Step 2: train the whole network.
    elif step == 2:
        print('===> Step 2 ...')
        bnn = BCNN(pretrained=False, n_classes=num_classes)
        bnn = nn.DataParallel(bnn).cuda()
        optimizer = optim.Adam(bnn.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    else:
        raise AssertionError('Wrong step argument')

    loadmodel = 'checkpoint.pth'

    # Check whether we are resuming from a previous checkpoint.
    print(
        '-----------------------------------------------------------------------------'
    )
    if resume:
        assert os.path.isfile(
            loadmodel), 'please make sure checkpoint.pth exists'
        print('---> loading checkpoint.pth <---')
        checkpoint = torch.load(loadmodel)
        assert step == checkpoint[
            'step'], 'step in checkpoint does not match step in argument'
        start_epoch = checkpoint['epoch']
        best_accuracy = checkpoint['best_accuracy']
        best_epoch = checkpoint['best_epoch']
        bnn.load_state_dict(checkpoint['bnn_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        Cross_entropy = checkpoint['Cross_entropy']
        logits_softmax = checkpoint['logits_softmax']
    else:
        if step == 2:
            print('--->        step2 checkpoint loaded         <---')
            # Initialise step 2 from the best step-1 weights.
            bnn.load_state_dict(
                torch.load('model/bnn_step1_vgg16_best_epoch.pth'))
        else:
            print('--->        no checkpoint loaded         <---')
        Cross_entropy = []
        logits_softmax = []
        start_epoch = 0
        best_accuracy = 0.0
        # Stays None until some epoch beats best_accuracy.
        best_epoch = None
    print(
        '-----------------------------------------------------------------------------'
    )

    with open(logfile, "a") as f:
        f.write('------ Step: {} ...\n'.format(step))
        # Terminate the header with a newline so the first epoch line is not
        # appended directly after the closing '------'.
        f.write('------\n'
                'drop rate: [{}]\tT_k: [{}]\t'
                'start epoch: [{}]\t'
                '\n------\n'.format(drop_rate, T_k, start))

    # The best-model file is written into model/; create it up front so a
    # missing directory does not crash after a full training epoch.
    os.makedirs('model', exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()

        bnn.train()
        adjust_learning_rate(optimizer, epoch)

        # train() also returns the updated Cross_entropy / logits_softmax
        # lists, which are stored in the checkpoint below.
        train_acc, logits_softmax, Cross_entropy = train(
            train_loader, epoch, bnn, optimizer, logits_softmax, Cross_entropy)

        # Optional: dump the cross-entropy output (cross_entropy, image path,
        # image label, image id) to inspect the selection result by
        # uncommenting the code below.
        # if len(Cross_entropy) > 0:
        #     pickle.dump(Cross_entropy, open(cross_entropy_savapath + 'crossentropy_epoch{}_step{}.pkl'.format(epoch + 1,step), 'wb'))

        test_acc = evaluate(test_loader, bnn)

        if test_acc > best_accuracy:
            best_accuracy = test_acc
            best_epoch = epoch + 1
            torch.save(bnn.state_dict(),
                       'model/bnn_step{}_vgg16_best_epoch.pth'.format(step))

        epoch_end_time = time.time()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'bnn_state_dict': bnn.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_epoch': best_epoch,
                'best_accuracy': best_accuracy,
                'step': step,
                'Cross_entropy': Cross_entropy,
                'logits_softmax': logits_softmax
            },
            filename=loadmodel)

        epoch_runtime = epoch_end_time - epoch_start_time
        print('------\n'
              'Epoch: [{:03d}/{:03d}]\tTrain Accuracy: [{:6.2f}]\t'
              'Test Accuracy: [{:6.2f}]\t'
              'Epoch Runtime: [{:6.2f}]\t'
              '\n------'.format(epoch + 1, num_epochs, train_acc, test_acc,
                                epoch_runtime))
        with open(logfile, "a") as f:
            output = ('Epoch: [{:03d}/{:03d}]\tTrain Accuracy: [{:6.2f}]\t'
                      'Test Accuracy: [{:6.2f}]\t'
                      'Epoch Runtime: [{:6.2f}]\t'.format(
                          epoch + 1, num_epochs, train_acc, test_acc,
                          epoch_runtime))
            f.write(output + "\n")

    # best_epoch is still None when no epoch improved on best_accuracy (e.g.
    # num_epochs <= start_epoch, or every test accuracy was 0.0); '{:03d}'
    # would raise TypeError on None, so format the epoch separately.
    if best_epoch is None:
        best_epoch_label = 'N/A'
    else:
        best_epoch_label = '{:03d}'.format(best_epoch)

    print('******\n'
          'Best Accuracy 1: [{0:6.2f}], at Epoch [{1}] '
          '\n******'.format(best_accuracy, best_epoch_label))
    with open(logfile, "a") as f:
        output = ('******\n'
                  'Best Accuracy 1: [{0:6.2f}], at Epoch [{1}]; '
                  '\n******'.format(best_accuracy, best_epoch_label))
        f.write(output + "\n")