def execute(args):
    """Train an encoder/decoder pair that predicts projections onto CG beads.

    The encoder maps a batch's features and geometry to logits. A
    Gumbel-softmax of the logits soft-assigns rows of ``geo`` to
    ``args.ncg`` coarse-grained (CG) beads. Each bead position is the
    assignment-weighted average of ``geo`` (the assignment is normalised
    over axis 1 before the einsum).

    The decoder receives fixed per-bead input features plus the *detached*
    bead positions. It is trained to reproduce ``otp.project_onto_cg`` of
    the atom positions relative to each bead. The assignment used for that
    projection is chosen by ``args.gumble_sm_proj`` / ``args.nearest`` /
    straight-through (the default).

    When ``args.fm`` is set, a force-matching term (squared norm of the
    assignment-summed forces) is added from epoch ``args.fm_epoch`` on.

    Training stops after ``args.epochs`` epochs or once ``args.wall``
    seconds of wall-clock time have elapsed, whichever comes first.

    Args:
        args: run-configuration namespace (seed, sizes, schedule, flags).

    Returns:
        dict with ``"args"``, ``"dynamics"`` (scalar losses per optimizer
        step), ``"summaries"`` (one record per epoch, including the epoch
        cut short by the wall limit, holding tensors from its last step),
        and ``"encoder"``/``"decoder"`` state dicts when
        ``args.save_state`` is true (else ``None``).
    """
    geometries, forces, features = otp.data(args)

    if args.cg_ones:
        cg_features = torch.ones(
            args.bs, args.ncg, 1, dtype=args.precision, device=args.device
        )
    else:
        # One-hot input per bead: entry [b, k, k] is 1, everything else 0.
        cg_features = torch.zeros(
            args.bs, args.ncg, args.ncg, dtype=args.precision, device=args.device
        )
        cg_features.scatter_(
            -1,
            torch.arange(args.ncg, device=args.device)
            .expand(args.bs, args.ncg)
            .unsqueeze(-1),
            1.0,
        )

    encoder = Encoder(args).to(device=args.device)
    decoder = Decoder(args).to(device=args.device)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr
    )
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )

    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs
    )

    dynamics = []
    summaries = []
    wall_start = perf_counter()
    torch.manual_seed(args.seed)
    out_of_time = False
    for epoch in tqdm(range(args.epochs)):
        for step, batch in tqdm(
            enumerate(torch.randperm(n_batches, device=args.device))
        ):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Auto encoder
            logits = encoder(feat, geo)
            cg_assign, st_cg_assign = gumbel_softmax(
                logits, temp_sched[epoch], dtype=args.precision, device=args.device
            )
            # Normalise over axis 1 so each bead position is a weighted
            # average of the rows of ``geo``.
            E = cg_assign / cg_assign.sum(1).unsqueeze(1)
            cg_xyz = torch.einsum("zij,zik->zjk", E, geo)

            # End goal is projection of atoms by atomic number onto coarse
            # grained atom. The projection target is computed from detached
            # CPU copies of the positions.
            relative_xyz = (
                geo.unsqueeze(2).cpu().detach() - cg_xyz.unsqueeze(1).cpu().detach()
            )
            nearest_assign = nearest_assignment(cg_xyz, geo)
            if args.gumble_sm_proj:
                cg_proj = otp.project_onto_cg(relative_xyz, cg_assign, feat, args)
            elif args.nearest:
                cg_proj = otp.project_onto_cg(relative_xyz, nearest_assign, feat, args)
            else:
                cg_proj = otp.project_onto_cg(relative_xyz, st_cg_assign, feat, args)

            pred_sph = decoder(cg_features, cg_xyz.clone().detach())
            cg_proj = cg_proj.reshape_as(pred_sph)
            loss_ae = (cg_proj - pred_sph).pow(2).sum(-1).div(args.atomic_nums).mean()

            if args.fm and epoch >= args.fm_epoch:
                # Force matching: penalise the squared norm of the forces
                # summed per bead under a (re-sampled) assignment.
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
                cg_force = torch.einsum("zij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()
                loss = loss_ae + args.fm_co * loss_fm
            else:
                loss_fm = torch.tensor(0)
                loss = loss_ae

            dynamics.append(
                {
                    "loss_ae": loss_ae.item(),
                    "loss_fm": loss_fm.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                    "step": step,
                    "batch": batch.item(),
                }
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            wall = perf_counter() - wall_start
            if wall > args.wall:
                # ``break`` only leaves the step loop; the flag also ends the
                # epoch loop after this epoch's summary is recorded.
                out_of_time = True
                break

        summaries.append(
            {
                "loss_ae": loss_ae.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
                "cg_xyz": cg_xyz,
                "pred_sph": pred_sph,
                "sph": cg_proj,
                "temp": temp_sched[epoch].item(),
                "gumble": cg_assign,
                "st_gumble": st_cg_assign,
                "nearest": nearest_assign,
            }
        )
        if out_of_time:
            break

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        "encoder": encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }
def execute(args):
    """Train an equivariant decoder against a frozen, pretrained dense autoencoder.

    A dense encoder/decoder is loaded from the checkpoint at ``args.dense``,
    and their ``weight1``/``weight`` tensors are detached in place. Only the
    parameters of the new ``equi.Decoder`` are handed to the optimizer.

    Each step proceeds as follows:

    1. The dense encoder produces CG positions.
    2. A Gumbel-softmax of ``encoder_dense.weight1.t()`` gives the atom-to-bead
       assignment.
    3. ``otp.project_onto_cg`` of the atom positions relative to each bead
       is the target.
    4. The equivariant decoder predicts that target from per-bead input
       features. Which features are used is selected by ``args.cg_ones`` /
       ``args.project_one`` / ``args.soln`` / default one-hot.

    Only ``loss_ae_equi`` is optimised. ``loss_ae_dense`` and ``loss_fm``
    are computed and logged, but are not part of the loss.

    NOTE: every step uses batch index 0; the shuffled iteration over all
    batches is commented out.

    Training stops after ``args.epochs`` epochs or once ``args.wall``
    seconds of wall-clock time have elapsed, whichever comes first.

    Args:
        args: run-configuration namespace (checkpoint path, sizes, schedule,
            feature-mode flags).

    Returns:
        dict with ``"args"``, ``"dynamics"`` (scalar losses per optimizer
        step), ``"summaries"`` (one record per epoch, including the epoch
        cut short by the wall limit) and ``"decoder"`` (state dict when
        ``args.save_state`` is true, else ``None``).

    Raises:
        NotImplementedError: if ``args.cg_specific_atom`` is set and none of
            the earlier feature flags is.
    """
    dense_dict = torch.load(args.dense, map_location=args.device)
    geometries, forces, features = otp.data(args)

    encoder_dense = dense.Encoder(
        in_dim=geometries.size(1), out_dim=args.ncg, device=args.device
    ).to(args.device)
    encoder_dense.load_state_dict(dense_dict["encoder"])
    encoder_dense.weight1.detach_()
    decoder_dense = dense.Decoder(in_dim=args.ncg, out_dim=geometries.size(1)).to(
        args.device
    )
    decoder_dense.load_state_dict(dense_dict["decoder"])
    decoder_dense.weight.detach_()

    # Input representation of the equivariant decoder; must match the
    # per-bead features built inside the step loop.
    if args.cg_ones:
        Rs_in = [[(1, 0)]]
    elif args.project_one:
        Rs_in = [[(1, 0), (1, 2)]]
    elif args.soln:
        Rs_in = [[(1, l) for mul, l in enumerate(range(args.proj_lmax + 1))] * 2]
    elif args.cg_specific_atom:
        raise NotImplementedError()
        # Rs_in = [[(1, 0), (1, args.cg_specific_atom)]]
    else:
        Rs_in = [[(args.ncg, 0)]]

    # Encoder... TBD
    decoder = equi.Decoder(args, Rs_in).to(device=args.device)
    # optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr)
    optimizer = torch.optim.Adam(list(decoder.parameters()), lr=args.lr)
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )

    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs
    )

    dynamics = []
    summaries = []
    wall_start = perf_counter()
    torch.manual_seed(args.seed)
    out_of_time = False
    for epoch in tqdm(range(args.epochs)):
        # for step, batch in tqdm(enumerate(torch.randperm(n_batches, device=args.device))):
        for step, batch in tqdm(
            enumerate([torch.tensor(0, device=args.device)] * n_batches)
        ):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Auto encoder
            cg_xyz = encoder_dense(geo, temp_sched[epoch])
            logits = encoder_dense.weight1.t()
            cg_assign, st_cg_assign = gumbel_softmax(
                logits, temp_sched[epoch], dtype=args.precision, device=args.device
            )
            decoded = decoder_dense(cg_xyz)
            loss_ae_dense = (decoded - geo).pow(2).sum(-1).mean()

            # Calculate Ground Truth
            # End goal is projection of atoms by atomic number onto coarse
            # grained atom.
            relative_xyz = (
                geo.unsqueeze(2).cpu().detach() - cg_xyz.unsqueeze(1).cpu().detach()
            )
            nearest_assign = equi.nearest_assignment(cg_xyz, geo)
            if args.gumble_sm_proj:
                cg_proj = otp.project_onto_cg(relative_xyz, cg_assign, feat, args)
            elif args.nearest:
                cg_proj = otp.project_onto_cg(relative_xyz, nearest_assign, feat, args)
            else:
                cg_proj = otp.project_onto_cg(relative_xyz, st_cg_assign, feat, args)
            cg_proj = cg_proj.reshape(args.bs, args.ncg, -1)

            # Features and Predict
            if args.cg_ones:
                cg_features = torch.ones(
                    args.bs, args.ncg, 1, dtype=args.precision, device=args.device
                )
            elif args.project_one:
                # The purpose is to select the nearest atom to a cg_atom,
                # project it, give a single atom that feature.
                relative_xyz = geo.unsqueeze(2) - cg_xyz.unsqueeze(1)
                nearest_atom_ind = relative_xyz.norm(dim=-1).argmin(1).squeeze()
                cg_atom_ind = 2
                # NOTE(review): ``nearest_atom_ind[cg_atom_ind]`` indexes the
                # leading axis. If that axis is the batch (args.bs > 1, so the
                # squeeze keeps it), this picks batch 2 rather than bead 2 —
                # confirm this branch is only used with bs == 1.
                l2_features = otp.project_atom_onto_cg_features(
                    relative_xyz,
                    2,
                    nearest_atom_ind[cg_atom_ind],
                    cg_atom_ind,
                    dtype=args.precision,
                    device=args.device,
                )
                l1_features = torch.ones(
                    *l2_features.shape[:2], 1, device=args.device, dtype=args.precision
                )
                cg_features = torch.cat([l1_features, l2_features], dim=-1)
            elif args.soln:
                cg_features = cg_proj.clone()  # Give solution
            elif args.cg_specific_atom:
                raise NotImplementedError()
                # # The purpose is to select the nearest atom to a cg_atom, project it, give a single atom that feature.
                # relative_xyz = geo.unsqueeze(2) - cg_xyz.unsqueeze(1)
                # nearest_atom_ind = relative_xyz.norm(dim=-1).argmin(1).squeeze()
                # cg_atom_ind = 2
                # l2_features = otp.project_atom_onto_cg_features(relative_xyz, args.cg_specific_atom, nearest_atom_ind[cg_atom_ind],
                #                                                 cg_atom_ind, dtype=args.precision, device=args.device)
                # l1_features = torch.ones(
                #     *l2_features.shape[:2], 1, device=args.device, dtype=args.precision
                # )
                # cg_features = torch.cat([l1_features, l2_features], dim=-1)
            else:
                # One-hot input per bead: entry [b, k, k] is 1, everything else 0.
                cg_features = torch.zeros(
                    args.bs,
                    args.ncg,
                    args.ncg,
                    dtype=args.precision,
                    device=args.device,
                )
                cg_features.scatter_(
                    -1,
                    torch.arange(args.ncg, device=args.device)
                    .expand(args.bs, args.ncg)
                    .unsqueeze(-1),
                    1.0,
                )

            pred_sph = decoder(cg_features, cg_xyz.clone().detach())

            # Loss
            loss_ae_equi = (
                (cg_proj - pred_sph).pow(2).sum(-1).div(args.atomic_nums).mean()
            )
            if args.fm and epoch >= args.fm_epoch:
                # Force matching (logged only; not added to ``loss``).
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
                cg_force = torch.einsum("...ij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()
                # loss = loss_ae_equi + loss_ae_dense + args.fm_co * loss_fm
            else:
                loss_fm = torch.tensor(0)
                # loss = loss_ae_equi + loss_ae_dense
            loss = loss_ae_equi

            dynamics.append(
                {
                    "loss_ae_equi": loss_ae_equi.item(),
                    "loss_ae_dense": loss_ae_dense.item(),
                    "loss_fm": loss_fm.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                    "step": step,
                    "batch": batch.item(),
                }
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            wall = perf_counter() - wall_start
            if wall > args.wall:
                # ``break`` only leaves the step loop; the flag also ends the
                # epoch loop after this epoch's summary is recorded.
                out_of_time = True
                break

        summaries.append(
            {
                "loss_ae_equi": loss_ae_equi.item(),
                "loss_ae_dense": loss_ae_dense.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
                "cg_xyz": cg_xyz,
                "pred_sph": pred_sph,
                "sph": cg_proj,
                "temp": temp_sched[epoch].item(),
                "gumble": cg_assign,
                "st_gumble": st_cg_assign,
                "nearest": nearest_assign,
            }
        )
        if out_of_time:
            break

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        # 'encoder': encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }
def execute(args):
    """Train a dense autoencoder that reconstructs geometry from CG positions.

    The encoder maps a batch's features and geometry to logits. A
    Gumbel-softmax of the logits soft-assigns rows of ``geo`` to coarse-grained
    (CG) beads, and each bead position is the assignment-weighted average of
    ``geo`` (the assignment is normalised over axis 1 before the einsum). The
    decoder reconstructs ``geo`` from the bead positions. The auto-encoder
    loss is the squared error summed over the last axis and averaged.

    When ``args.fm`` is set, a force-matching term (squared norm of the
    assignment-summed forces, weighted by ``args.fm_co``) is added from
    epoch ``args.fm_epoch`` on.

    Training stops after ``args.epochs`` epochs or once ``args.wall``
    seconds of wall-clock time have elapsed, whichever comes first.

    Args:
        args: run-configuration namespace (seed, sizes, schedule, flags).

    Returns:
        dict with ``"args"``, ``"dynamics"`` (scalar losses per optimizer
        step), ``"summaries"`` (one record per epoch, including the epoch
        cut short by the wall limit, holding tensors from its last step),
        and ``"encoder"``/``"decoder"`` state dicts when
        ``args.save_state`` is true (else ``None``).
    """
    geometries, forces, features = otp.data(args)

    encoder = Encoder(args).to(device=args.device)
    decoder = Decoder(in_dim=args.ncg, out_dim=geometries.size(1)).to(args.device)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr
    )
    temp_sched = temp_scheduler(
        args.epochs,
        args.tdr,
        args.temp,
        args.tmin,
        dtype=args.precision,
        device=args.device,
    )

    n_batches, geometries, forces, features = otp.batch(
        geometries, forces, features, args.bs
    )

    dynamics = []
    summaries = []
    wall_start = perf_counter()
    torch.manual_seed(args.seed)
    out_of_time = False
    for epoch in tqdm(range(args.epochs)):
        for step, batch in tqdm(
            enumerate(torch.randperm(n_batches, device=args.device))
        ):
            feat, geo, force = features[batch], geometries[batch], forces[batch]

            # Auto encoder
            logits = encoder(feat, geo)
            cg_assign, st_cg_assign = gumbel_softmax(
                logits, temp_sched[epoch], dtype=args.precision, device=args.device
            )
            # Normalise over axis 1 so each bead position is a weighted
            # average of the rows of ``geo``.
            E = cg_assign / cg_assign.sum(1).unsqueeze(1)
            cg_xyz = torch.einsum("zij,zik->zjk", E, geo)

            decoded = decoder(cg_xyz)
            loss_ae = (decoded - geo).pow(2).sum(-1).mean()

            if args.fm and epoch >= args.fm_epoch:
                # Force matching: penalise the squared norm of the forces
                # summed per bead under a (re-sampled) assignment.
                cg_force_assign, _ = gumbel_softmax(
                    logits,
                    temp_sched[epoch] * args.force_temp_coeff,
                    device=args.device,
                    dtype=args.precision,
                )
                cg_force = torch.einsum("zij,zik->zjk", cg_force_assign, force)
                loss_fm = cg_force.pow(2).sum(-1).mean()
                loss = loss_ae + args.fm_co * loss_fm
            else:
                loss_fm = torch.tensor(0)
                loss = loss_ae

            dynamics.append(
                {
                    "loss_ae": loss_ae.item(),
                    "loss_fm": loss_fm.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                    "step": step,
                    "batch": batch.item(),
                }
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            wall = perf_counter() - wall_start
            if wall > args.wall:
                # ``break`` only leaves the step loop; the flag also ends the
                # epoch loop after this epoch's summary is recorded.
                out_of_time = True
                break

        summaries.append(
            {
                "loss_ae": loss_ae.item(),
                "loss_fm": loss_fm.item(),
                "loss": loss.item(),
                "epoch": epoch,
                "step": step,
                "batch": batch.item(),
                "cg_xyz": cg_xyz,
                "temp": temp_sched[epoch].item(),
                "gumble": cg_assign,
                "st_gumble": st_cg_assign,
                "reconstructed": decoded,
            }
        )
        if out_of_time:
            break

    return {
        "args": args,
        "dynamics": dynamics,
        "summaries": summaries,
        # 'train': {
        #     'pred': evaluate(f, features, geometry, train[:len(test)]),
        #     'true': forces[train[:len(test)]],
        # },
        # 'test': {
        #     'pred': evaluate(f, features, geometry, test[:len(train)]),
        #     'true': forces[test[:len(train)]],
        # },
        "encoder": encoder.state_dict() if args.save_state else None,
        "decoder": decoder.state_dict() if args.save_state else None,
    }