# Script setup: pull in the training/model dependencies, then build the CLI
# from the shared `otp` option sets and parse it immediately at import time.
from argparse import ArgumentParser
from time import perf_counter

import torch.optim
import torch.utils.data
from tqdm import tqdm

from cgae.gs import gumbel_softmax
from cgae.cgae import temp_scheduler
from cgae.equi import Encoder, Decoder, nearest_assignment
import otp

# The parser defines no options of its own; everything is inherited from the
# general OTP parser and the equivariant-model parser (in that order).
parser = ArgumentParser(
    add_help=True,
    parents=[
        otp.otp_parser(),
        otp.otp_equi_parser(),
    ],
)

# NOTE(review): parsing happens at module import, so importing this file
# triggers argument parsing via otp.parse_args (presumably reading sys.argv).
args = otp.parse_args(parser)

# NOTE(review): disabled draft kept for reference. It uses `sys`, which is not
# imported above, so it would need `import sys` if re-enabled. It also appears
# to continue past this point.
# def autoencoder(args):
#     pass
#
#
# def evaluate(f, features, geometry, indicies):
#     with torch.no_grad():
#         outs = []
#         for i in tqdm(range(0, len(indicies), 50), file=sys.stdout):
#             sys.stdout.flush()
#             batch = indicies[i: i + 50]
#             out = f(features[batch], geometry[batch])  # [batch, atom, xyz]
# Script setup: dependencies plus a CLI that extends the shared `otp` option
# sets with options for loading a dense model and running a single example.
from argparse import ArgumentParser
from time import perf_counter

import torch.utils.data
from tqdm import tqdm
from se3cnn.SO3 import spherical_harmonics_xyz

from cgae.gs import gumbel_softmax
from cgae.cgae import temp_scheduler
import cgae.cgae_dense as dense
import cgae.equi as equi
import otp

# Inherit the general OTP options and the equivariant-model options.
parser = ArgumentParser(parents=[otp.otp_parser(), otp.otp_equi_parser()])

# Mandatory path to a pickled dict holding the 'encoder' and 'decoder' entries.
parser.add_argument(
    "--dense",
    required=True,
    type=str,
    help="Pickle dict with 'encoder' and 'decoder' keys.",
)

# Boolean switch: run the single-example test path instead of the default one.
parser.add_argument(
    "--single_example",
    help="Test the single example instead.",
    action="store_true",
)

# Options in this group cannot be combined on the command line.
# NOTE(review): only one option is visible here; presumably more are added to
# `se_options` further down the file. A one-member group constrains nothing.
se_options = parser.add_mutually_exclusive_group()
se_options.add_argument(
    "--project_one",
    help="Project an atom onto one cg atom in the single example.",
    action="store_true",
)
import torch.optim import torch.utils.data from time import perf_counter from tqdm import tqdm from argparse import ArgumentParser from cgae.gs import gumbel_softmax from cgae.cgae import temp_scheduler from cgae.equi import Encoder from cgae.cgae_dense import Decoder import otp parser = ArgumentParser(parents=[otp.otp_parser(), otp.otp_equi_parser()], add_help=True) args = otp.parse_args(parser) def execute(args): geometries, forces, features = otp.data(args) encoder = Encoder(args).to(device=args.device) decoder = Decoder(in_dim=args.ncg, out_dim=geometries.size(1)).to(args.device) optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr) temp_sched = temp_scheduler(