Ejemplo n.º 1
0
def test(network_save_file):
    """Evaluate a saved checkpoint of the module-level ``net`` on the validation set.

    Loads the state dict stored under ``'net'`` in *network_save_file* into
    ``net``, builds a ``bot`` state (the live model, a deep-copied validation
    model and the optimizers selected by ``args.optimizer``), runs
    ``bot.valid_steps`` over ``valid_batches(batch_size, transforms)`` without
    gradient tracking, prints the ``bot.epoch_stats`` summary and a
    progress-bar line, and returns the ``'acc'`` entry of that summary.

    Relies on module-level names defined elsewhere in the script: ``net``,
    ``bot``, ``args``, ``opt_params``, ``opt_params_bias``, ``valid_set``,
    ``valid_batches``, ``batch_size``, ``transforms``, ``test_dataset`` and
    ``progress_bar``.

    Args:
        network_save_file: path handed to ``torch.load``. The loaded object
            must contain a ``'net'`` state dict.

    Returns:
        ``bot.epoch_stats(res)['acc']`` for the validation pass.
    """
    checkpoint = torch.load(network_save_file)
    net.load_state_dict(checkpoint['net'])

    # Put the model in eval mode *before* it is deep-copied below, so the
    # bot.VALID_MODEL copy inherits eval mode too (deepcopy keeps the
    # per-module `training` flags).
    net.eval()

    # Split trainable parameters by whether their name contains 'bias'; the
    # two groups get separate optimizer hyper-parameters.
    is_bias = bot.group_by_key(
        ('bias' in k, v) for k, v in bot.trainable_params(net).items())

    valid_model = bot.copy.deepcopy(net)
    if args.optimizer == 'sgd':
        opts = [bot.SGD(is_bias[False], opt_params),
                bot.SGD(is_bias[True], opt_params_bias)]
    else:
        opts = [bot.SAM1(is_bias[False], opt_params, sam_flag=True),
                bot.SAM2(is_bias[True], opt_params_bias, sam_flag=True)]
    state = {bot.MODEL: net, bot.VALID_MODEL: valid_model, bot.OPTS: opts}

    total = len(valid_set['data'])

    with torch.no_grad():
        res = bot.reduce(0, valid_batches(batch_size, transforms), state, bot.valid_steps)
        # Non-zero entries of the logged 'acc' activations are counted as
        # correct predictions.
        correct = np.count_nonzero(res['activation_log']['acc'])
        stats = bot.epoch_stats(res)
        print(stats)

        # NOTE(review): confirm whether res['output'] holds the whole pass or
        # only the last batch; if the latter, this is a last-batch loss.
        test_loss = np.mean(res['output']['loss'].detach().cpu().numpy())
        progress_bar(len(test_dataset), len(test_dataset),
                     'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss, 100. * correct / total, correct, total))
    return stats['acc']
Ejemplo n.º 2
0
    'weight_decay':
    bot.Const(5e-4 * batch_size),
    'momentum':
    bot.Const(0.9)
}
# Optimizer settings for the bias parameter group.
# - Learning rate: built by lr_schedule from the knots
#   [0, epochs / 5, epochs - ema_epochs], the values [0.0, 1.0 * 64, 0.1 * 64]
#   and batch_size.
# - Weight decay: 5e-4 * batch_size / 64.
# - Momentum: 0.9.
opt_params_bias = dict(
    lr=lr_schedule([0, epochs / 5, epochs - ema_epochs],
                   [0.0, 1.0 * 64, 0.1 * 64], batch_size),
    weight_decay=bot.Const(5e-4 * batch_size / 64),
    momentum=bot.Const(0.9),
)

# Group trainable parameters under a boolean key: True when the parameter
# name contains 'bias', False otherwise.
is_bias = bot.group_by_key(
    ('bias' in name, param)
    for name, param in bot.trainable_params(net).items()
)

# Training state: the live model, a deep copy registered as bot.VALID_MODEL,
# and one SGD optimizer per parameter group.
state = {
    bot.MODEL: net,
    bot.VALID_MODEL: bot.copy.deepcopy(net),
    bot.OPTS: [
        bot.SGD(is_bias[False], opt_params),
        bot.SGD(is_bias[True], opt_params_bias),
    ],
}
timer = bot.Timer(torch.cuda.synchronize)

print('=====> Training...')

# Histories, presumably appended to by the training loop that follows.
train_times, train_accs = [], []