示例#1
0
    # Train for CONFIG['TRAINING']['EPOCH'] epochs. After each epoch, evaluate
    # on the validation set. Save a "best" checkpoint whenever top-1 improves,
    # plus a periodic snapshot every `save_every` epochs.
    epochs = CONFIG['TRAINING']['EPOCH']
    save_every = 10  # periodic snapshot interval, in epochs
    best_top1 = 0.0
    for epoch in range(epochs):
        print("Epoch {}".format(epoch + 1))
        epoch_time = time.time()
        trainer.train(epoch)
        scheduler.step()
        # NOTE: printed after scheduler.step(), so these are presumably the
        # learning rates for the *next* epoch. Also assumes the optimizer has
        # at least two param groups (index 1 raises IndexError otherwise).
        print("LR1/LR2: [{}/{}], Train Time: {:.2f}".format(
            optimizer.param_groups[0]['lr'],
            optimizer.param_groups[1]['lr'],
            time.time() - epoch_time
        ))

        print('-' * 60)

        val_top1 = evaluate.test(topk=(1, 2, 5))
        # trainer.eval(epoch)

        is_best = val_top1 > best_top1
        is_periodic = (epoch + 1) % save_every == 0
        if is_best or is_periodic:
            # A wrapped model (e.g. nn.DataParallel) exposes the underlying
            # network as `.module`; save that network's weights. A plain model
            # has no `.module`, so fall back to the model itself. Resolving
            # this explicitly (instead of catching AttributeError) avoids
            # masking genuine AttributeErrors raised while saving.
            state = getattr(model, 'module', model).state_dict()
            if is_best:
                torch.save(state, f"{checkpoint}/model_best.pth.tar")
            if is_periodic:
                torch.save(state, f"{checkpoint}/model_{epoch + 1}.pth.tar")

        if is_best:
            best_top1 = val_top1
示例#2
0
    # Attribute-level evaluation is switched off. Set this to True to build the
    # AttributeTrainer below (it is only constructed when the model uses
    # attributes and reweighting is off).
    run_attribute_eval = False
    if run_attribute_eval and with_attribute and not reweighting:
        trainer = AttributeTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_attrs=len(attrs),
            save_dir=save_dir,
            optimizer=None,
            summary_writer=None,
            attribute_list=attrs,
            with_attribute=with_attribute,
            num_classes=CONFIG['DATASET']['NUM_CATEGORY'],
            criterion=nn.CrossEntropyLoss(),
            config=CONFIG)

    # Move the model to the GPU when one is available, and replicate it across
    # devices when more than one GPU is visible.
    if torch.cuda.is_available():
        model = model.cuda()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model.eval()
    # Attribute-level evaluation; only valid when the trainer above was built.
    # if run_attribute_eval and with_attribute and not reweighting:
    #     trainer.eval(epoch=100)

    # NOTE(review): this reads WITH_ATTRIBUTE from CONFIG rather than the local
    # `with_attribute` used above; confirm the two always agree.
    evaluate = Evaluation(model=model,
                          dataloader=val_loader,
                          classes=class_names,
                          ten_crops=CONFIG['TESTING']['TEN_CROPS'],
                          with_attribute=CONFIG['MODEL']['WITH_ATTRIBUTE'])
    evaluate.test(topk=(1, 2, 5))