def test(network_save_file):
    """Evaluate the network restored from *network_save_file* on the validation set.

    Loads the checkpoint into the module-level ``net``, rebuilds the
    training-style optimizer state (SGD or SAM variants, matching
    ``args.optimizer``), runs one no-grad validation pass via ``bot.reduce``,
    prints epoch statistics plus a progress line, and returns the accuracy
    reported by ``bot.epoch_stats``.

    NOTE(review): depends on module-level globals — net, bot, args,
    opt_params, opt_params_bias, valid_set, valid_batches, batch_size,
    transforms, test_dataset, progress_bar — confirm they are defined
    before this is called.

    Parameters
    ----------
    network_save_file : path-like
        Checkpoint file with keys 'net', 'acc', 'epoch'.

    Returns
    -------
    The 'acc' entry of the epoch-stats dict produced by ``bot.epoch_stats``.
    """
    checkpoint = torch.load(network_save_file)
    net.load_state_dict(checkpoint['net'])
    training_acc = checkpoint['acc']      # recorded in checkpoint; not used below
    training_epoch = checkpoint['epoch']  # recorded in checkpoint; not used below

    # Split trainable params into bias / non-bias groups, mirroring training.
    is_bias = bot.group_by_key(
        ('bias' in k, v) for k, v in bot.trainable_params(net).items())

    # Build the optimizer list once; the surrounding state dict is identical
    # for both branches (previously duplicated), only the optimizers differ.
    if args.optimizer == 'sgd':
        opts = [bot.SGD(is_bias[False], opt_params),
                bot.SGD(is_bias[True], opt_params_bias)]
    else:
        opts = [bot.SAM1(is_bias[False], opt_params, sam_flag=True),
                bot.SAM2(is_bias[True], opt_params_bias, sam_flag=True)]
    state = {
        bot.MODEL: net,
        bot.VALID_MODEL: bot.copy.deepcopy(net),
        bot.OPTS: opts,
    }

    net.eval()
    total = len(valid_set['data'])
    with torch.no_grad():
        res = bot.reduce(0, valid_batches(batch_size, transforms),
                         state, bot.valid_steps)

    # Per-sample correctness flags from the activation log.
    acc = res['activation_log']['acc']
    correct = np.count_nonzero(acc)
    res2 = bot.epoch_stats(res)
    print(res2)
    test_loss = np.mean(res['output']['loss'].detach().cpu().numpy())
    progress_bar(len(test_dataset), len(test_dataset),
                 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                 % (test_loss, 100.*correct/total, correct, total))
    return res2['acc']
'weight_decay': bot.Const(5e-4 * batch_size), 'momentum': bot.Const(0.9) } opt_params_bias = { 'lr': lr_schedule([0, epochs / 5, epochs - ema_epochs], [0.0, 1.0 * 64, 0.1 * 64], batch_size), 'weight_decay': bot.Const(5e-4 * batch_size / 64), 'momentum': bot.Const(0.9) } is_bias = bot.group_by_key( ('bias' in k, v) for k, v in bot.trainable_params(net).items()) state, timer = { bot.MODEL: net, bot.VALID_MODEL: bot.copy.deepcopy(net), bot.OPTS: [ bot.SGD(is_bias[False], opt_params), bot.SGD(is_bias[True], opt_params_bias) ] }, bot.Timer(torch.cuda.synchronize) print('=====> Training...') train_times = [] train_accs = []