예제 #1
0
def main():
    """Run the train/validate loop and print the best validation accuracy.

    Options come from ``parse_option()``; the data loaders, the model,
    classifier and criterion, and the optimizer are built by the
    ``set_*`` helpers. Each epoch adjusts the learning rate, trains once,
    validates once, and the highest validation accuracy seen is printed
    after the final epoch.
    """
    opt = parse_option()

    # Data loaders for the training and validation splits.
    train_loader, val_loader = set_loader(opt)

    # Model, classifier and loss criterion.
    model, classifier, criterion = set_model(opt)

    # Both the classifier and the model are handed to the optimizer
    # (not the classifier alone).
    optimizer = set_optimizer(opt, [classifier, model])

    top_acc = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # Timed training pass; the returned loss is not used here.
        started = time.time()
        _, train_acc = train(
            train_loader, model, classifier, criterion, optimizer, epoch, opt
        )
        elapsed = time.time() - started
        summary = 'Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, elapsed, train_acc)
        print(summary)

        # Validation pass; remember the best accuracy so far.
        _, val_acc = validate(val_loader, model, classifier, criterion, opt)
        if val_acc > top_acc:
            top_acc = val_acc

    print('best accuracy: {:.2f}'.format(top_acc))
def main():
    """Train the classifier, logging per-epoch metrics, and report the best.

    Options come from ``parse_option()``. Each epoch adjusts the learning
    rate, trains once and validates once; loss, top-1 and top-5 accuracy
    of both passes are written to a ``tb_logger.Logger`` under the
    ``classifier/`` tags. After the last epoch the best validation top-1
    accuracy, and the top-5 accuracy recorded in that same epoch, are
    logged via ``logging.info``.
    """
    opt = parse_option()

    # Data loaders for the training and validation splits.
    train_loader, val_loader = set_loader(opt)

    # Model, classifier and loss criterion.
    model, classifier, criterion = set_model(opt)

    # Only the classifier is handed to the optimizer.
    optimizer = set_optimizer(opt, classifier)

    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    top1_best = 0
    top5_at_best = 0
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # Timed training pass.
        started = time.time()
        train_loss, train_acc1, train_acc5 = train(
            train_loader, model, classifier, criterion, optimizer, epoch, opt
        )
        elapsed = time.time() - started
        summary = 'Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, elapsed, train_acc1)
        logging.info(summary)

        for tag, value in (
            ('classifier/train_loss', train_loss),
            ('classifier/train_acc1', train_acc1),
            ('classifier/train_acc5', train_acc5),
        ):
            logger.log_value(tag, value, epoch)

        # Validation pass.
        val_loss, val_acc, val_acc5 = validate(
            val_loader, model, classifier, criterion, opt
        )
        for tag, value in (
            ('classifier/val_loss', val_loss),
            ('classifier/val_acc1', val_acc),
            ('classifier/val_acc5', val_acc5),
        ):
            logger.log_value(tag, value, epoch)

        # Track the best top-1 accuracy together with that epoch's top-5.
        if val_acc > top1_best:
            top1_best, top5_at_best = val_acc, val_acc5

    final_report = 'best accuracy: {:.2f}, accuracy5: {:.2f}'.format(
        top1_best, top5_at_best)
    logging.info(final_report)
예제 #3
0
def main():
    """Train (or only evaluate) the classifier, then run adversarial evaluation.

    Options come from ``parse_option()``. When ``opt.eval`` is set, a single
    validation pass is run with the classifier returned by ``set_model``.
    Otherwise the classifier is trained for ``opt.epochs`` epochs, and a
    snapshot of it is kept from the epoch with the highest validation
    accuracy. Finally ``adveval`` is run once per value in ``opt.epsilons``
    using that best classifier, and the adversarial accuracy is printed.
    """
    # Local import so the block stays self-contained.
    import copy

    best_acc = 0
    opt = parse_option()

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, criterion = set_model(opt)
    # Fallback for eval mode and for runs where no epoch beats best_acc.
    best_classifier = classifier

    # build optimizer
    optimizer = set_optimizer(opt, classifier)

    if opt.eval:
        # The returned loss/accuracy are not used by this function.
        validate(val_loader, model, classifier, criterion, opt)
    else:
        # training routine
        for epoch in range(1, opt.epochs + 1):
            adjust_learning_rate(opt, optimizer, epoch)

            # train for one epoch
            time1 = time.time()
            _, acc = train(train_loader, model, classifier, criterion,
                           optimizer, epoch, opt)
            time2 = time.time()
            print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
                epoch, time2 - time1, acc))

            # eval for one epoch
            _, val_acc = validate(val_loader, model, classifier, criterion,
                                  opt)
            if val_acc > best_acc:
                best_acc = val_acc
                # Snapshot the classifier. A plain assignment would only
                # alias the live object, which later epochs keep updating,
                # so the "best" classifier would silently become the last.
                best_classifier = copy.deepcopy(classifier)

        print('best accuracy: {:.2f}'.format(best_acc))

    for epsilon in opt.epsilons:
        _, _, adv_acc = adveval(val_loader, model, best_classifier,
                                criterion, opt, epsilon)
        print('adv accuracy at epsilon {:.2f}: {:.2f}'.format(
            epsilon, adv_acc))