def test(net, log=None, batch_size=128, data='cifar100'):
    """
    Evaluate a trained model on the test split of a dataset.

    :param net: model to be tested (inputs are moved to CUDA, so the model
        is expected to live on the GPU)
    :param log: optional open log file; results are appended when given
    :param batch_size: batch size for the test loader
    :param data: dataset name: 'cifar10', 'cifar100', 'svhn', 'mnist'
        or 'tinyimagenet'
    :return: wall-clock inference time in seconds
    :raises ValueError: if ``data`` names an unknown dataset
    """
    net.eval()
    is_train = False

    # Dispatch table instead of an if/elif chain; unknown names now fail
    # loudly (the original silently called exit()).
    loaders = {
        'cifar10': load_cifar10,
        'cifar100': load_cifar100,
        'svhn': load_svhn,
        'mnist': load_mnist,
        'tinyimagenet': load_tiny_imagenet,
    }
    if data not in loaders:
        raise ValueError('Unknown dataset: %r' % data)
    test_loader = loaders[data](is_train, batch_size)

    correct = 0
    total = 0
    inference_start = time.time()
    with torch.no_grad():
        # BUGFIX: loop variable used to be named `data`, shadowing the
        # dataset-name parameter.
        for inputs, labels in test_loader:
            outputs, outputs_conv = net(inputs.cuda())
            # softmax is monotonic, so argmax over logits would be identical;
            # kept for parity with the training-side accuracy computation.
            _, predicted = torch.max(F.softmax(outputs, -1), 1)
            total += labels.size(0)
            # .item() keeps `correct` a plain Python int instead of
            # accumulating CUDA tensors.
            correct += (predicted == labels.cuda()).sum().item()
    inference_time = time.time() - inference_start

    accuracy = 100.0 * correct / total
    print('Accuracy: %f %%; Inference time: %fs' % (accuracy, inference_time))
    if log is not None:
        log.write('Accuracy of the network on the 10000 test images: %f %%\n'
                  % accuracy)
        log.write('Inference time is: %fs\n' % inference_time)
    return inference_time
def train(net, lr, log=None, optimizer_option='SGD', data='cifar100',
          epochs=350, batch_size=128, is_train=True, net_st=None, beta=0.0,
          lrd=10):
    """
    Train the model, optionally distilling features from a teacher network.

    :param net: model to be trained (expected on CUDA)
    :param lr: initial learning rate
    :param log: optional open log file
    :param optimizer_option: optimizer type forwarded to get_optimizer
    :param data: dataset name: 'cifar10', 'cifar100', 'svhn', 'mnist'
        or 'tinyimagenet'
    :param epochs: number of training epochs
    :param batch_size: batch size
    :param is_train: whether the training loader uses the training split
    :param net_st: uncompressed teacher model; None disables transfer loss
    :param beta: transfer-loss weight
    :param lrd: learning-rate decay divisor passed to adjust_lr
    :return: state dict snapshot — best validation accuracy for
        'tinyimagenet', last validated weights otherwise
    :raises ValueError: if ``data`` names an unknown dataset
    """
    net.train()
    if net_st is not None:
        net_st.eval()

    # BUGFIX: 'mnist' previously left `valloader` undefined, causing a
    # NameError at the first validation epoch. None now disables validation.
    valloader = None
    if data == 'cifar10':
        trainloader = load_cifar10(is_train, batch_size)
        valloader = load_cifar10(False, batch_size)
    elif data == 'cifar100':
        trainloader = load_cifar100(is_train, batch_size)
        valloader = load_cifar100(False, batch_size)
    elif data == 'svhn':
        trainloader = load_svhn(is_train, batch_size)
        valloader = load_svhn(False, batch_size)
    elif data == 'mnist':
        trainloader = load_mnist(is_train, batch_size)
    elif data == 'tinyimagenet':
        trainloader, valloader = load_tiny_imagenet(is_train, batch_size)
    else:
        raise ValueError('Unknown dataset: %r' % data)

    criterion = nn.CrossEntropyLoss()
    criterion_mse = nn.MSELoss()
    optimizer = get_optimizer(net, lr, optimizer_option)

    start_time = time.time()
    # BUGFIX: was 0, which made the first per-100-iteration timing report
    # the absolute epoch timestamp instead of an interval.
    last_time = start_time
    best_acc = 0

    def _snapshot():
        # BUGFIX: net.state_dict() returns live tensor references, so the
        # saved "best" weights silently tracked further training. Clone so
        # the snapshot is frozen at save time.
        return {k: v.detach().clone() for k, v in net.state_dict().items()}

    best_param = _snapshot()
    iteration = 0
    for epoch in range(epochs):
        print("****************** EPOCH = %d ******************" % epoch)
        if log is not None:
            log.write("****************** EPOCH = %d ******************\n"
                      % epoch)
        total = 0
        correct = 0
        loss_sum = 0
        # Step the learning rate at the standard milestones.
        if epoch == 150 or epoch == 250:
            lr = adjust_lr(lr, lrd=lrd, log=log)
            optimizer = get_optimizer(net, lr, optimizer_option)

        # BUGFIX: loop variable used to be named `data`, shadowing the
        # dataset-name parameter and breaking the `data == 'tinyimagenet'`
        # validation branch after the first epoch.
        for batch in trainloader:
            iteration += 1
            # forward
            inputs, labels = batch
            inputs_V, labels_V = Variable(inputs.cuda()), Variable(
                labels.cuda())
            outputs, outputs_conv = net(inputs_V)
            loss = criterion(outputs, labels_V)
            if net_st is not None:
                outputs_st, outputs_st_conv = net_st(inputs_V)
                # Feature-transfer loss between student and teacher conv
                # maps; intermediate layers are down-weighted by 1/50, the
                # last layer gets full weight. BUGFIX: inner index used to
                # reuse (and clobber) the outer loop variable `i`.
                for j in range(len(outputs_st_conv)):
                    if j != (len(outputs_st_conv) - 1):
                        loss += beta / 50 * criterion_mse(
                            outputs_conv[j], outputs_st_conv[j].detach())
                    else:
                        loss += beta * criterion_mse(
                            outputs_conv[j], outputs_st_conv[j].detach())
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(F.softmax(outputs, -1), 1)
            total += labels_V.size(0)
            # .item() keeps the counters as Python numbers instead of
            # accumulating CUDA tensors. BUGFIX: `loss_sum += loss` kept the
            # autograd graph of every batch alive (memory leak).
            correct += (predicted == labels_V).sum().item()
            loss_sum += loss.item()
            if iteration % 100 == 99:
                now_time = time.time()
                print('accuracy: %f %%; loss: %f; time: %ds' %
                      ((float(100) * float(correct) / float(total)), loss,
                       (now_time - last_time)))
                if log is not None:
                    log.write(
                        'accuracy: %f %%; loss: %f; time: %ds\n' %
                        ((float(100) * float(correct) / float(total)), loss,
                         (now_time - last_time)))
                total = 0
                correct = 0
                loss_sum = 0
                last_time = now_time

        # Validation every 10 epochs (when a validation loader exists).
        if valloader is not None and epoch % 10 == 9:
            net.eval()
            val_acc = validation(net, valloader, log)
            net.train()
            if data == 'tinyimagenet':
                # Keep only the best-performing snapshot.
                if val_acc > best_acc:
                    best_acc = val_acc
                    best_param = _snapshot()
            else:
                # Other datasets simply keep the latest validated weights.
                best_param = _snapshot()

    print('Finished Training. It took %ds in total'
          % (time.time() - start_time))
    if log is not None:
        log.write('Finished Training. It took %ds in total\n'
                  % (time.time() - start_time))
    return best_param