Example #1
0
                        help='path where the dataset is saved')
    # Path to a full serialized model (saved via torch.save(model, path)),
    # restored below with torch.load.
    parser.add_argument('--ckpt_path',
                        type=str,
                        default="./checkpoint",
                        help='path where the checkpoint to be saved')
    parser.add_argument('--device_id',
                        type=int,
                        default=0,
                        help='device id of GPU. (Default: 0)')
    args = parser.parse_args()

    # Select the target CUDA device and restore the entire model object
    # (not just a state_dict) from the checkpoint path.
    device = torch.device('cuda:' + str(args.device_id))
    network = torch.load(args.ckpt_path)
    network.to(device)
    # NOTE(review): network.eval() is never called in this fragment —
    # dropout/batch-norm layers stay in whatever mode the checkpoint was
    # saved in; confirm the checkpoint was saved in eval mode.

    # CIFAR-10 test split (is_train=False); measure top-1 accuracy.
    dataloader = create_dataset_pytorch_cifar10(args.data_path, is_train=False)
    with torch.no_grad():
        time_start = time.time()
        total_samples = 0.0
        correct_samples = 0.0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = network(inputs)
            # Predicted class = argmax of the logits along the last dim.
            _, max_index = torch.max(outputs, dim=-1)
            total_samples += labels.size(0)
            # NOTE(review): .sum() yields a 0-d tensor, so the printed
            # accuracy below is a tensor, not a float; .item() would print
            # a plain number.
            correct_samples += (max_index == labels).sum()
        print('Accuracy: {}'.format(correct_samples / total_samples),
              flush=True)
        time_end = time.time()
        # Wall-clock time for the whole evaluation pass.
        time_step = time_end - time_start
Example #2
0
    # Bind to the requested GPU and build an Xception classifier in
    # training mode.
    device = torch.device('cuda:' + str(args.device_id))
    network = Xception(num_classes=cfg.num_classes)
    network.train()
    network.to(device)
    criterion = nn.CrossEntropyLoss()
    # Earlier RMSprop optimizer, kept commented out for reference:
    #     optimizer = optim.RMSprop(network.parameters(),
    #                                 lr=cfg.lr_init,
    #                                 eps=cfg.rmsprop_epsilon,
    #                                 momentum=cfg.rmsprop_momentum,
    #                                 alpha=cfg.rmsprop_decay)

    optimizer = optim.SGD(network.parameters(),
                          lr=cfg.lr_init,
                          momentum=cfg.SGD_momentum)
    dataloader = create_dataset_pytorch_cifar10(args.data_path)
    step_per_epoch = len(dataloader)
    # Step-wise LR decay: multiply the LR by lr_decay_rate every
    # lr_decay_epoch epochs, expressed here in optimizer steps
    # (epochs * steps-per-epoch) since the scheduler counts steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          gamma=cfg.lr_decay_rate,
                                          step_size=cfg.lr_decay_epoch *
                                          step_per_epoch)
    # Bounded queue of checkpoints; presumably used to cap saved
    # checkpoints at keep_checkpoint_max — TODO confirm in the (unseen)
    # save logic further down.
    q_ckpt = Queue(maxsize=cfg.keep_checkpoint_max)

    global_step_id = 0
    # NOTE(review): the full run would iterate range(cfg.epoch_size); this
    # fragment is hard-coded to a single epoch (original header kept
    # commented out above the loop).
    #     for epoch in range(cfg.epoch_size):
    for epoch in range(1):
        time_epoch = 0.0
        for i, data in enumerate(dataloader, 0):
            time_start = time.time()
            inputs, labels = data
            inputs = inputs.to(device)