def validate():
    """Evaluate the model on the ``validation`` set.

    Returns the root-mean-square energy error converted from Hartree
    to kcal/mol, averaged over all molecules seen.
    """
    sum_sq_loss = torch.nn.MSELoss(reduction='sum')
    squared_error = 0.0
    n_molecules = 0
    for batch in validation:
        species = batch['species'].to(device)
        coordinates = batch['coordinates'].to(device).float()
        true_energies = batch['energies'].to(device).float()
        _, predicted_energies = model((species, coordinates))
        # Accumulate the summed squared error and the molecule count so the
        # final RMSE is taken over the whole set, not per batch.
        squared_error += sum_sq_loss(predicted_energies, true_energies).item()
        n_molecules += predicted_energies.shape[0]
    return hartree2kcalmol(math.sqrt(squared_error / n_molecules))
def evaluate(self, dataset):
    """Run the evaluation"""
    accumulated_sq_err = 0.0
    seen = 0
    for batch in dataset:
        species = batch['species'].to(self.device)
        coords = batch['coordinates'].to(self.device).float()
        targets = batch['energies'].to(self.device).float()
        _, predictions = self.model((species, coords))
        # self.mse_sum is expected to reduce with 'sum', so dividing the
        # running total by the molecule count yields the mean squared error.
        accumulated_sq_err += self.mse_sum(predictions, targets).item()
        seen += predictions.shape[0]
    # RMSE in kcal/mol over every molecule in `dataset`.
    return hartree2kcalmol(math.sqrt(accumulated_sq_err / seen))
def do_benchmark(model):
    """Print MAE and RMSE (kcal/mol) of energies, relative energies and
    forces for `model` over every molecule file under ``parser.dir``."""
    dataset = recursive_h5_files(parser.dir)
    # One (MAE, RMSE) averager pair per reported quantity.
    quantities = ('energy', 'relative', 'force')
    mae_avg = {q: Averager() for q in quantities}
    rmse_avg = {q: Averager() for q in quantities}
    for entry in tqdm.tqdm(dataset, position=0, desc="dataset"):
        # Load one molecule's conformations onto the target device.
        coordinates = torch.tensor(entry['coordinates'], device=parser.device)
        species = model.species_to_tensor(entry['species']) \
            .unsqueeze(0).expand(coordinates.shape[0], -1)
        energies = torch.tensor(entry['energies'], device=parser.device)
        forces = torch.tensor(entry['forces'], device=parser.device)
        # Predict, then form the error terms.
        energies2, forces2 = by_batch(species, coordinates, model)
        errors = {
            'energy': energies - energies2,
            'relative': relative_energies(energies) - relative_energies(energies2),
            'force': forces.flatten() - forces2.flatten(),
        }
        # Feed |err| into the MAE averagers and err^2 into the RMSE ones.
        for q in quantities:
            mae_avg[q].update(errors[q].abs())
            rmse_avg[q].update(errors[q] ** 2)

    def report(q):
        # Convert the accumulated means to kcal/mol; RMSE needs the sqrt.
        return (hartree2kcalmol(mae_avg[q].compute()),
                hartree2kcalmol(math.sqrt(rmse_avg[q].compute())))

    print("Energy:", *report('energy'))
    print("Relative Energy:", *report('relative'))
    print("Forces:", *report('force'))
# Fragment of a training-loop body: forward, backward, and optimizer step for
# one batch, with NVTX ranges and autograd NVTX emission enabled only while
# PROFILING_STARTED is set (so profiler timelines get per-phase markers).
if PROFILING_STARTED:
    torch.cuda.nvtx.range_push(
        "batch{}".format(total_batch_counter))
species = properties['species'].to(parser.device)
coordinates = properties['coordinates'].to(parser.device).float()
true_energies = properties['energies'].to(parser.device).float()
# Atoms per molecule; entries with species < 0 are presumably padding -- TODO confirm.
num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED,
                                       record_shapes=True):
    _, predicted_energies = model((species, coordinates))
# Per-molecule squared error scaled down by sqrt(num_atoms), then averaged.
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
# NOTE(review): despite the name, no sqrt is applied here -- this is the mean
# squared error passed through hartree2kcalmol, not a true RMSE; the same
# pattern appears elsewhere in this file. Confirm whether that is intended.
rmse = hartree2kcalmol(
    (mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
if PROFILING_STARTED:
    torch.cuda.nvtx.range_push("backward")
with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED,
                                       record_shapes=True):
    loss.backward()
if PROFILING_STARTED:
    torch.cuda.nvtx.range_pop()
if PROFILING_STARTED:
    torch.cuda.nvtx.range_push("optimizer.step()")
with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED,
                                       record_shapes=True):
    optimizer.step()
def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
    """Train an ANI model over `dataset` while timing each pipeline stage.

    Builds an AEVComputer + neural-network model, monkeypatches the
    torchani.aev functions (and model/optimizer forward/step) with timing
    wrappers, runs ``parser.num_epochs`` epochs, then prints a per-epoch
    breakdown of where the time went.

    Args:
        parser: parsed CLI options; ``.device`` and ``.num_epochs`` are read.
        dataset: iterable of property dicts with 'species', 'coordinates',
            'energies' (and 'forces' when ``force_inference`` is True).
        use_cuda_extension: forwarded to torchani.AEVComputer.
        force_inference: when True, compute forces via autograd on the
            coordinates and add a force term to the loss; backward and
            optimizer.step are then skipped (see the
            ``if not force_inference`` branch below).
    """
    # When True, sync_cuda() blocks until the GPU is idle so the wall-clock
    # timers measure the work just issued rather than async launch overhead.
    synchronize = True
    timers = {}

    def time_func(key, func):
        # Wrap `func` so its cumulative wall time accumulates in timers[key].
        timers[key] = 0

        def wrapper(*args, **kwargs):
            start = timeit.default_timer()
            ret = func(*args, **kwargs)
            sync_cuda(synchronize)
            end = timeit.default_timer()
            timers[key] += end - start
            return ret
        return wrapper

    # AEV hyperparameters: radial/angular cutoffs plus the eta/shift grids.
    # These look like standard ANI parameters -- TODO confirm which model.
    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=parser.device)
    ShfR = torch.tensor([
        9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
        1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
        3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
        4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00
    ], device=parser.device)
    Zeta = torch.tensor([3.2000000e+01], device=parser.device)
    ShfZ = torch.tensor([
        1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
        1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00
    ], device=parser.device)
    EtaA = torch.tensor([8.0000000e+00], device=parser.device)
    ShfA = torch.tensor(
        [9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00],
        device=parser.device)
    num_species = 4
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta,
                                        ShfA, ShfZ, num_species,
                                        use_cuda_extension)
    nn = torchani.ANIModel(build_network())
    model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    # reduction='none' keeps per-molecule losses so they can be scaled by
    # num_atoms before averaging.
    mse = torch.nn.MSELoss(reduction='none')

    # enable timers: replace the torchani.aev internals (and the model's and
    # optimizer's entry points) with timing wrappers keyed by name.
    torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine',
                                           torchani.aev.cutoff_cosine)
    torchani.aev.radial_terms = time_func('torchani.aev.radial_terms',
                                          torchani.aev.radial_terms)
    torchani.aev.angular_terms = time_func('torchani.aev.angular_terms',
                                           torchani.aev.angular_terms)
    torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts',
                                            torchani.aev.compute_shifts)
    torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs',
                                            torchani.aev.neighbor_pairs)
    torchani.aev.neighbor_pairs_nopbc = time_func(
        'torchani.aev.neighbor_pairs_nopbc',
        torchani.aev.neighbor_pairs_nopbc)
    torchani.aev.triu_index = time_func('torchani.aev.triu_index',
                                        torchani.aev.triu_index)
    torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero',
                                              torchani.aev.cumsum_from_zero)
    torchani.aev.triple_by_molecule = time_func(
        'torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
    torchani.aev.compute_aev = time_func('torchani.aev.compute_aev',
                                         torchani.aev.compute_aev)
    # 'total' is the whole AEV stage; 'forward' is the network stage.
    model[0].forward = time_func('total', model[0].forward)
    model[1].forward = time_func('forward', model[1].forward)
    optimizer.step = time_func('optimizer.step', optimizer.step)

    print('=> start training')
    start = time.time()
    loss_time = 0
    force_time = 0

    for epoch in range(0, parser.num_epochs):
        print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
        for i, properties in enumerate(dataset):
            species = properties['species'].to(parser.device)
            # requires_grad only when forces will be taken w.r.t. coordinates.
            coordinates = properties['coordinates'].to(
                parser.device).float().requires_grad_(force_inference)
            true_energies = properties['energies'].to(parser.device).float()
            # Atoms per molecule; species < 0 presumably marks padding --
            # TODO confirm.
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            _, predicted_energies = model((species, coordinates))
            # TODO add sync after aev is done
            sync_cuda(synchronize)
            energy_loss = (mse(predicted_energies, true_energies)
                           / num_atoms.sqrt()).mean()
            if force_inference:
                sync_cuda(synchronize)
                force_coefficient = 0.1
                true_forces = properties['forces'].to(parser.device).float()
                force_start = time.time()
                try:
                    sync_cuda(synchronize)
                    # Forces are minus the gradient of the total energy
                    # w.r.t. coordinates; create_graph keeps the graph so a
                    # force loss could itself be differentiated.
                    forces = -torch.autograd.grad(predicted_energies.sum(),
                                                  coordinates,
                                                  create_graph=True,
                                                  retain_graph=True)[0]
                    sync_cuda(synchronize)
                except Exception as e:
                    # Abort the whole benchmark on any autograd failure.
                    alert('Error: {}'.format(e))
                    return
                force_time += time.time() - force_start
                force_loss = (mse(true_forces, forces).sum(dim=(1, 2))
                              / num_atoms).mean()
                loss = energy_loss + force_coefficient * force_loss
                sync_cuda(synchronize)
            else:
                loss = energy_loss
            # NOTE(review): no sqrt is applied -- this is mean squared error
            # in kcal/mol, not a true RMSE; confirm whether that is intended.
            rmse = hartree2kcalmol(
                (mse(predicted_energies,
                     true_energies)).mean()).detach().cpu().numpy()
            progbar.update(i, values=[("rmse", rmse)])
            if not force_inference:
                # Time backward + optimizer step only in the energy-only
                # (training) configuration.
                sync_cuda(synchronize)
                loss_start = time.time()
                loss.backward()
                # print('2', coordinates.grad)
                sync_cuda(synchronize)
                loss_stop = time.time()
                loss_time += loss_stop - loss_start
                optimizer.step()
                sync_cuda(synchronize)
        checkgpu()
    sync_cuda(synchronize)
    stop = time.time()

    print('=> More detail about benchmark PER EPOCH')
    # All reported figures are averaged over the epochs.
    total_time = (stop - start) / parser.num_epochs
    loss_time = loss_time / parser.num_epochs
    force_time = force_time / parser.num_epochs
    opti_time = timers['optimizer.step'] / parser.num_epochs
    forward_time = timers['forward'] / parser.num_epochs
    aev_time = timers['total'] / parser.num_epochs
    print_timer(' Total AEV', aev_time)
    print_timer(' Forward', forward_time)
    print_timer(' Backward', loss_time)
    print_timer(' Force', force_time)
    print_timer(' Optimizer', opti_time)
    # Whatever is left after subtracting the tracked stages.
    print_timer(
        ' Others',
        total_time - loss_time - aev_time - forward_time - opti_time
        - force_time)
    print_timer(' Epoch time', total_time)