import gc
import json
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import nets
import util

# eval_error(net, dataloader, criterion) -> (accuracy, loss) is referenced by
# evaluate() and train() below; it is assumed to be defined alongside these
# helpers.


def print_predictions(X, fids, parameters, mapping, file_name,
                      models_directory='../models/'):
    hidden_sizes = parameters.get('hidden_sizes', [200])
    output_dim = parameters.get('output_dim', -1)
    model_prefix = parameters.get('model_prefix', None)
    use_cuda = parameters.get('use_cuda', False)

    net = nets.AdvancedNet(X.shape[1], hidden_sizes, output_dim)
    net.load_state_dict(torch.load(models_directory + model_prefix))
    X = torch.from_numpy(X).float()
    if use_cuda:
        net = net.cuda()
        X = X.cuda()

    net.eval()
    y = net(X)
    max_index = y.max(dim=1)[1]

    output_dict = {}
    for i in range(len(fids)):
        output_dict[fids[i]] = mapping[max_index[i].item()]
    with open(file_name, 'w') as fp:
        json.dump(output_dict, fp, indent=4)
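
# Example usage (a sketch only; the feature file, track ids, label mapping,
# and checkpoint name below are hypothetical placeholders):
#
#   X = np.load('features.npy')                  # shape (n_examples, n_features)
#   fids = ['track_0001', 'track_0002']          # one id per row of X
#   mapping = {0: 'rock', 1: 'jazz'}             # class index -> label name
#   parameters = {'hidden_sizes': [200], 'output_dim': 2,
#                 'model_prefix': 'advanced_net.12',
#                 'use_cuda': torch.cuda.is_available()}
#   print_predictions(X, fids, parameters, mapping, 'predictions.json')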


def evaluate_and_get_results(dataloader, parameters,
                             models_directory='./models/'):
    batch_size = parameters.get('batch_size', 100)
    hidden_sizes = parameters.get('hidden_sizes', [200])
    input_dim = parameters.get('input_dim')
    output_dim = parameters.get('output_dim')
    model_prefix = parameters.get('model_prefix', None)
    model_suffix = parameters.get('model_suffix', None)
    use_cuda = parameters.get('use_cuda', False)

    # load network
    model_path = models_directory + model_prefix + '.' + model_suffix
    parameters['logger'].log('Loading trained model: ' + model_path)
    criterion = torch.nn.CrossEntropyLoss()
    net = nets.AdvancedNet(input_dim, hidden_sizes, output_dim)
    if use_cuda:
        criterion = criterion.cuda()
        net = net.cuda()
        net.load_state_dict(torch.load(model_path))
    else:
        net.load_state_dict(
            torch.load(model_path, map_location=torch.device('cpu')))
    # the dataloader yields double-precision inputs, so cast the network to
    # double once here (fixes: RuntimeError: expected scalar type Float but
    # found Double) rather than on every batch
    net = net.double()

    # evaluating
    num_batches = len(dataloader)
    total_loss = 0.0
    accuracy = 0.0
    net.eval()
    results = np.zeros(parameters['data_num'])
    pos = 0
    parameters['logger'].log('Start testing, batches num: ' + str(num_batches))
    for inputs, labels in dataloader:
        # replace NaN entries with 0
        inputs = torch.where(torch.isnan(inputs),
                             torch.full_like(inputs, 0), inputs)
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net(inputs)
        if criterion is not None:
            total_loss += criterion(outputs, labels).item()
        max_index = outputs.max(dim=1)[1]
        result = max_index.data.cpu().numpy()
        accuracy += np.sum(result == labels.data.cpu().numpy()) \
            / inputs.size()[0]
        results[pos:pos + len(max_index)] = result
        pos = pos + len(max_index)

    net.train()
    accuracy = accuracy / num_batches
    total_loss = total_loss / num_batches
    parameters['logger'].log('accuracy: ' + str(accuracy))
    parameters['logger'].log('total loss: ' + str(total_loss))
    return accuracy, total_loss, results
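
# Example usage (a sketch only; the dataloader, dataset, and logger are assumed
# to be built elsewhere -- the logger only needs a .log(str) method):
#
#   parameters = {'input_dim': 128, 'output_dim': 10, 'hidden_sizes': [200],
#                 'model_prefix': 'advanced_net', 'model_suffix': '7',
#                 'use_cuda': False, 'data_num': len(test_dataset),
#                 'logger': logger}
#   acc, loss, preds = evaluate_and_get_results(test_dataloader, parameters)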


def evaluate(model_suffix, dataloader, parameters,
             models_directory='./models/'):
    batch_size = parameters.get('batch_size', 100)
    hidden_sizes = parameters.get('hidden_sizes', [200])
    input_dim = parameters.get('input_dim')
    output_dim = parameters.get('output_dim')
    model_prefix = parameters.get('model_prefix', None)

    criterion = nn.CrossEntropyLoss().cuda()
    net = nets.AdvancedNet(input_dim, hidden_sizes, output_dim).cuda()
    net.load_state_dict(
        torch.load(models_directory + model_prefix + '.' + model_suffix))

    accuracy, total_loss = eval_error(net, dataloader, criterion)
    print('Test acc: %.4f, loss: %.4f' % (accuracy, total_loss))
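
# Example usage (a sketch only; assumes a CUDA device and a checkpoint saved as
# '<models_directory>/<model_prefix>.<model_suffix>'):
#
#   parameters = {'input_dim': 128, 'output_dim': 10, 'hidden_sizes': [200],
#                 'model_prefix': 'advanced_net'}
#   evaluate('12', test_dataloader, parameters)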


def train(train_dataloader, val_dataloader, test_dataloader, parameters,
          models_directory='../models', suffix=''):
    print('Starting training at ' + util.get_time())
    print(', '.join(
        ['{}={!r}'.format(k, v) for k, v in sorted(parameters.items())]))

    # batch_size is determined in the dataloader, so the variable is
    # irrelevant here
    batch_size = parameters.get('batch_size', 200)
    num_epochs = parameters.get('num_epochs', 100)
    hidden_sizes = parameters.get('hidden_sizes', [200])
    learning_rate = parameters.get('learning_rate', 0.0005)
    weight_decay = parameters.get('weight_decay', 0)
    dropout = parameters.get('dropout', 0.0)
    patience = parameters.get('patience', 10)
    threshold = parameters.get('threshold', 1e-3)
    input_dim = parameters['input_dim']
    output_dim = parameters['output_dim']
    # output_period: output training loss every x batches
    output_period = parameters.get('output_period', 0)
    model_prefix = parameters.get('model_prefix', None)
    only_train = parameters.get('only_train', False)
    save_model = parameters.get('save_model', False)
    test_best = parameters.get('test_best', False)
    print_test = (parameters.get('print_test', False)
                  and test_dataloader is not None)

    # we want to print out test accuracies to a separate file
    test_file = None
    if print_test:
        test_file = open('test{}.txt'.format(suffix), 'a')
        test_file.write('\n\n')
        test_file.write('Starting at ' + util.get_time() + '\n')
        test_file.write(', '.join(
            ['{}={!r}'.format(k, v)
             for k, v in sorted(parameters.items())]) + '\n\n')

    # nets and optimizers
    criterion = nn.CrossEntropyLoss().cuda()
    net = nets.AdvancedNet(input_dim, hidden_sizes, output_dim,
                           dropout=dropout).cuda()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate,
                           weight_decay=weight_decay)
    # ReduceLROnPlateau reduces the learning rate by a factor of 10 once val
    # loss has plateaued
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=patience,
                                                     threshold=threshold)

    num_train_batches = len(train_dataloader)
    epoch = 1
    best_epoch, best_acc = 0, 0
    train_acc = [0]
    print('starting training')
    while epoch <= num_epochs:
        running_loss = 0.0
        epoch_acc = 0.0
        net.train()
        print('epoch: %d, lr: %.1e'
              % (epoch, optimizer.param_groups[0]['lr'])
              + ' ' + util.get_time())
        for batch_num, (inputs, labels) in enumerate(train_dataloader, 1):
            optimizer.zero_grad()
            inputs, labels = inputs.cuda(), labels.cuda()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # outputs is a 2D array of per-class scores (log-softmax output);
            # taking the max index along dim 1 gives the predicted class for
            # each example, and we calculate accuracy off that
            max_index = outputs.max(dim=1)[1]
            epoch_acc += np.sum(max_index.data.cpu().numpy()
                                == labels.data.cpu().numpy()) \
                / inputs.size()[0]

            # output every output_period batches
            if output_period and batch_num % output_period == 0:
                print('[%d:%.2f] loss: %.3f' % (
                    epoch, batch_num * 1.0 / num_train_batches,
                    running_loss / output_period))
                running_loss = 0.0
                gc.collect()

        # save model after every epoch in the models/ folder
        if save_model:
            torch.save(net.state_dict(),
                       models_directory + '/' + model_prefix + '.%d' % epoch)

        # print training/val accuracy
        epoch_acc = epoch_acc / num_train_batches
        train_acc.append(epoch_acc)
        print('train acc: %.4f' % epoch_acc)
        if only_train:
            scheduler.step(loss)
        else:
            val_accuracy, total_loss = eval_error(net, val_dataloader,
                                                  criterion)
            print('val acc: %.4f, loss: %.4f' % (val_accuracy, total_loss))
            # remember: feed val loss into the scheduler
            scheduler.step(total_loss)
            if val_accuracy > best_acc:
                best_epoch, best_acc = epoch, val_accuracy
        print()

        # write test accuracy
        if print_test:
            test_accuracy, total_loss = eval_error(net, test_dataloader,
                                                   criterion)
            test_file.write('epoch: %d' % epoch + ' ' + util.get_time() + '\n')
            test_file.write('train acc: %.4f' % epoch_acc + '\n')
            if not only_train:
                test_file.write('val acc: %.4f' % val_accuracy + '\n')
            test_file.write('test acc: %.4f' % test_accuracy + '\n')
            test_file.write('loss: %.4f' % total_loss + '\n')

        gc.collect()

        # perform early stopping here once the learning rate falls below a
        # threshold, because a small lr means little change in accuracy anyway
        # (0.9 * 0.01 * learning_rate triggers after the scheduler has cut the
        # initial rate by two factors of 10)
        if optimizer.param_groups[0]['lr'] < (0.9 * 0.01 * learning_rate):
            print('Low LR reached, finishing training early')
            break
        epoch += 1

    print('best epoch: %d' % best_epoch)
    print('best val accuracy: %.4f' % best_acc)
    print('train accuracy at that epoch: %.4f' % train_acc[best_epoch])
    print('ending at', time.ctime())

    if test_best:
        net.load_state_dict(
            torch.load(models_directory + '/' + model_prefix + '.'
                       + str(best_epoch)))
        best_test_accuracy, total_loss = eval_error(net, test_dataloader,
                                                    criterion)
        if test_file is not None:
            test_file.write('*****\n')
            test_file.write('best test acc: %.4f, loss: %.4f'
                            % (best_test_accuracy, total_loss) + '\n')
            test_file.write('*****\n')
        print('best test acc: %.4f, loss: %.4f'
              % (best_test_accuracy, total_loss))

    if print_test:
        test_file.write('\n')
        test_file.close()
    print('\n\n\n')
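
# Example usage (a sketch only; the three dataloaders are assumed to be built
# elsewhere, and the hyperparameters shown are illustrative, not tuned values):
#
#   parameters = {'input_dim': 128, 'output_dim': 10, 'hidden_sizes': [200],
#                 'num_epochs': 50, 'learning_rate': 5e-4, 'dropout': 0.2,
#                 'model_prefix': 'advanced_net', 'save_model': True,
#                 'print_test': True, 'test_best': True}
#   train(train_dataloader, val_dataloader, test_dataloader, parameters)
#
# Note: test_best reloads the best epoch's checkpoint, so it requires
# save_model=True.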