def test_copied_examples():
    fname = datadir + "/ligonly.types"
    e = molgrid.ExampleProvider(data_root=datadir + "/structs")
    e.populate(fname)
    batch_size = 10
    b = e.next_batch(batch_size)
    for i in range(1, batch_size):
        sqsum = np.square(b[0].coord_sets[1].coords.tonumpy() - b[i].coord_sets[1].coords.tonumpy()).sum()
        assert sqsum > 0

    # now with duplicates
    e = molgrid.ExampleProvider(data_root=datadir + "/structs", num_copies=batch_size)
    e.populate(fname)
    b = e.next_batch(batch_size)
    for i in range(1, batch_size):
        sqsum = np.square(b[0].coord_sets[1].coords.tonumpy() - b[i].coord_sets[1].coords.tonumpy()).sum()
        assert sqsum == 0

    # transforming one of the duplicates should not affect the others
    orig = b[0].coord_sets[1].coords.tonumpy()
    lig = b[1].coord_sets[1]
    t = molgrid.Transform(lig.center(), 2, random_rotation=True)
    t.forward(lig, lig)
    new0 = b[0].coord_sets[1].coords.tonumpy()
    new1 = b[1].coord_sets[1].coords.tonumpy()
    np.testing.assert_allclose(orig, new0)
    sqsum = np.square(new1 - orig).sum()
    assert sqsum > 0
def test_backwards():
    g1 = molgrid.GridMaker(resolution=.1, dimension=6.0)
    c = np.array([[1.0, 0, 0]], np.float32)
    t = np.array([0], np.float32)
    r = np.array([2.0], np.float32)
    coords = molgrid.CoordinateSet(molgrid.Grid2f(c), molgrid.Grid1f(t), molgrid.Grid1f(r), 1)
    shape = g1.grid_dimensions(1)

    # make diff with gradient in center
    diff = molgrid.MGrid4f(*shape)
    diff[0, 30, 30, 30] = 1.0
    cpuatoms = molgrid.MGrid2f(1, 3)
    gpuatoms = molgrid.MGrid2f(1, 3)

    # apply random rotation
    T = molgrid.Transform((0, 0, 0), 0, True)
    T.forward(coords, coords)

    g1.backward((0, 0, 0), coords, diff.cpu(), cpuatoms.cpu())
    g1.backward((0, 0, 0), coords, diff.gpu(), gpuatoms.gpu())

    # untransform the atomic gradients so they can be compared in the original frame
    T.backward(cpuatoms.cpu(), cpuatoms.cpu(), False)
    T.backward(gpuatoms.gpu(), gpuatoms.gpu(), False)

    print(cpuatoms.tonumpy(), gpuatoms.tonumpy())
    # results should be ~ -.6, 0, 0
    np.testing.assert_allclose(cpuatoms.tonumpy(), gpuatoms.tonumpy(), atol=1e-5)
    np.testing.assert_allclose(cpuatoms.tonumpy().flatten(), [-0.60653067, 0, 0], atol=1e-5)
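# A quick numeric check of where the -0.60653067 above comes from, assuming
# libmolgrid's default Gaussian density kernel g(d) = exp(-2*d^2/r^2) inside
# the atom radius: with the atom at x=1 and radius r=2, the analytic gradient
# of the density at the grid center is dg/dx = (-4*x/r**2) * exp(-2*x**2/r**2),
# i.e. -exp(-0.5). This is a sketch of the expected value, not part of the test.
import numpy as np

x, r = 1.0, 2.0
analytic_grad = (-4 * x / r ** 2) * np.exp(-2 * x ** 2 / r ** 2)
assert np.isclose(analytic_grad, -0.60653067)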
def validate(val_loader, model, criterion, gmaker, tensorshape, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    r = AverageMeter('Pearson R', ':6.2f')
    rmse = AverageMeter('RMSE', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, r, rmse],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    predictions = []
    targets = []
    with torch.no_grad():
        end = time.time()
        for i, (lengths, center, coords, types, radii, labels) in enumerate(val_loader):
            types = types.cuda(args.gpu, non_blocking=True)
            radii = radii.squeeze().cuda(args.gpu, non_blocking=True)
            coords = coords.cuda(args.gpu, non_blocking=True)
            coords_q = torch.empty(*coords.shape, device=coords.device, dtype=coords.dtype)

            batch_size = coords.shape[0]
            if batch_size != types.shape[0] or batch_size != radii.shape[0]:
                raise RuntimeError("Inconsistent batch sizes in dataset outputs")

            output1 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
            for idx in range(batch_size):
                t = molgrid.Transform(
                    molgrid.float3(*(center[idx].numpy().tolist())),
                    random_translate=2, random_rotation=True)
                t.forward(coords[idx][:lengths[idx]], coords_q[idx][:lengths[idx]])
                gmaker.forward(t.get_rotation_center(), coords_q[idx][:lengths[idx]],
                               types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                               molgrid.tensor_as_grid(output1[idx]))

            del lengths, center, coords, types, radii
            torch.cuda.empty_cache()

            target = labels.cuda(args.gpu, non_blocking=True)

            # compute output
            prediction = model(output1)
            loss = criterion(prediction, target)

            # measure accuracy and record loss
            r_val, rmse_val = accuracy(prediction, target)
            losses.update(loss.item(), output1.size(0))
            r.update(r_val, output1.size(0))
            rmse.update(rmse_val, output1.size(0))

            predictions += prediction.detach().flatten().tolist()
            targets += target.detach().flatten().tolist()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    r_avg, rmse_avg = accuracy(predictions, targets)
    return r_avg, rmse_avg
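# The `accuracy` helper used by validate() and train() is not defined in these
# snippets. A minimal sketch of what it is assumed to compute, given how its
# return values are used (Pearson R and RMSE), accepting tensors or plain lists:
import numpy as np
import torch

def accuracy(prediction, target):
    """Return (pearson_r, rmse) between predictions and targets."""
    p = np.asarray(prediction.detach().cpu() if torch.is_tensor(prediction)
                   else prediction, dtype=np.float64).ravel()
    t = np.asarray(target.detach().cpu() if torch.is_tensor(target)
                   else target, dtype=np.float64).ravel()
    pearson_r = np.corrcoef(p, t)[0, 1]    # linear correlation coefficient
    rmse = np.sqrt(np.mean((p - t) ** 2))  # root-mean-square error
    return pearson_r, rmse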
def test_coordset_from_mol():
    m = pybel.readstring('smi', 'c1ccccc1CO')
    m.addh()
    m.make3D()

    c = molgrid.CoordinateSet(m, molgrid.ElementIndexTyper())
    oldcoord = c.coords.tonumpy()

    # simple translate: 16 atoms each shifted by (1,1,1) gives a summed difference of 48
    t = molgrid.Transform(molgrid.Quaternion(), (0, 0, 0), (1, 1, 1))
    t.forward(c, c)
    newcoord = c.coords.tonumpy()
    assert np.sum(newcoord - oldcoord) == approx(48)
def forward(self, transforms, **kwargs):
    # just interpolate the centers for now
    centers = torch.tensor(
        [tuple(t.get_rotation_center()) for t in transforms],
        dtype=torch.float64)
    centers = super().forward(centers, **kwargs)
    return [
        molgrid.Transform(
            t.get_quaternion(),
            tuple(center.numpy()),
            t.get_translation(),
        ) for t, center in zip(transforms, centers)
    ]
def test_coordset_from_array():
    coords = np.array([[1, 0, -1], [1, 3, -1], [1, 0, -1]], np.float32)
    types = np.array([3, 2, 1], np.float32)
    radii = np.array([1.5, 1.5, 1.0], np.float32)
    c = molgrid.CoordinateSet(coords, types, radii, 4)
    oldcoord = c.coords.tonumpy()

    # simple translate: the per-atom shift (-1,0,1) sums to zero,
    # so the total coordinate sum stays at 3.0
    t = molgrid.Transform(molgrid.Quaternion(), (0, 0, 0), (-1, 0, 1))
    t.forward(c, c)
    newcoord = c.coords.tonumpy()
    assert c.coords[1, 1] == 3.0
    assert np.sum(newcoord) == approx(3.0)

    # cloning produces an independent copy
    c2 = c.clone()
    c2.coords[1, 1] = 0
    assert c.coords[1, 1] == 3.0
obmol.addh()
print(obmol, end="")

# Use the OpenBabel molecule object (obmol.OBMol) instead of the Pybel molecule (obmol)
cs = molgrid.CoordinateSet(obmol.OBMol, t)

ex = molgrid.Example()
ex.coord_sets.append(cs)
c = ex.coord_sets[0].center()  # only one coordinate set
print("center:", tuple(c))

# https://gnina.github.io/libmolgrid/python/index.html#the-transform-class
transform = molgrid.Transform(
    c,
    random_translate=0.0,   # float
    random_rotation=False,  # bool
)
transform.forward(ex, ex)

# Compute grid
gm.forward(ex, grid[0])
print("grid.shape:", grid.shape)

if args.dx:
    # https://gnina.github.io/libmolgrid/python/index.html#molgrid.write_dx_grids
    # Grid4f is different from Grid4fCUDA
    # If a function takes Grid4f as input, the torch.Tensor needs to be moved to the CPU
    molgrid.write_dx_grids(
        f"grids/{system}",
        # the original snippet is truncated here; the remaining arguments below
        # are an assumed continuation based on the documented signature
        t.get_type_names(),
        grid[0].cpu(),
        c,
        gm.get_resolution(),
    )
def train(train_loader, model, criterion, optimizer, gmaker, tensorshape, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    slosses = AverageMeter('SuperLoss', ':.4e')
    # top1 = AverageMeter('Acc@1', ':6.2f')
    # top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, slosses],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    total_loss = 0
    super_loss = 0
    for i, (lengths, center, coords, types, radii, afflabel) in enumerate(train_loader):
        deltaG = afflabel.cuda(args.gpu, non_blocking=True)
        types = types.cuda(args.gpu, non_blocking=True)
        radii = radii.squeeze().cuda(args.gpu, non_blocking=True)
        coords = coords.cuda(args.gpu, non_blocking=True)
        coords_q = torch.empty(*coords.shape, device=coords.device, dtype=coords.dtype)

        batch_size = coords.shape[0]
        if i == 0:
            print(batch_size)
        if batch_size != types.shape[0] or batch_size != radii.shape[0]:
            raise RuntimeError("Inconsistent batch sizes in dataset outputs")

        output1 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
        output2 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
        for idx in range(batch_size):
            t = molgrid.Transform(
                molgrid.float3(*(center[idx].numpy().tolist())),
                random_translate=2, random_rotation=True)
            t.forward(coords[idx][:lengths[idx]], coords_q[idx][:lengths[idx]])
            gmaker.forward(t.get_rotation_center(), coords_q[idx][:lengths[idx]],
                           types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                           molgrid.tensor_as_grid(output1[idx]))
            t.forward(coords[idx][:lengths[idx]], coords[idx][:lengths[idx]])
            gmaker.forward(t.get_rotation_center(), coords[idx][:lengths[idx]],
                           types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                           molgrid.tensor_as_grid(output2[idx]))

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output, target, preds = model(im_q=output1, im_k=output2)
        loss = criterion(output, target)

        if args.semi_super:
            if i == 0:
                print(preds[:10])
                print(deltaG[:10])
            lossmask = deltaG.gt(0)
            sloss = torch.sum(lossmask * nn.functional.mse_loss(
                preds, deltaG, reduction='none')) / lossmask.sum()
            super_loss += sloss.item()
            loss += sloss
            slosses.update(sloss.item(), lossmask.sum())

        total_loss += loss.item()

        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        # acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), output1.size(0))
        # top1.update(acc1[0], images[0].size(0))
        # top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if args.semi_super:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    if args.semi_super:
        wandb.log({"Supervised Loss": super_loss / len(train_loader)}, commit=False)
    wandb.log({"Total Loss": total_loss / len(train_loader)})
def train_and_test(args, model, eptrain, eptest, gmaker):
    def test_model(model, ep, gmaker, percent_reduced, batch_size):
        # loss accumulation
        total_loss_mean = []
        rmsd_loss_mean = []
        atomic_grads_loss_mean = []

        # testing loop
        for j in range(int((percent_reduced / 100) * ep.size())):
            batch = ep.next_batch(batch_size)
            disp_vecs = []
            batch.extract_label(0, float_labels)
            labels = float_labels.to('cuda')
            for b in range(batch_size):
                try:
                    disp_vecs.append(batch[b].coord_sets[0].coords.tonumpy()
                                     - batch[b].coord_sets[2].coords.tonumpy())
                    disp_vecs[-1] = torch.from_numpy(disp_vecs[-1]).cuda()
                    del batch[b].coord_sets[0]
                except Exception:
                    print(batch[b].coord_sets[0].coords.tonumpy().shape,
                          batch[b].coord_sets[2].coords.tonumpy().shape)
                    disp_vecs.append(torch.zeros(batch[b].coord_sets[2].coords.tonumpy().shape,
                                                 dtype=torch.float32).cuda())
                    del batch[b].coord_sets[0]
                    continue

            # testing
            gmaker.forward(batch, input_tensor, 0, random_rotation=False)
            output = model(input_tensor)

            # use true RMSD values for grid gradients
            labels = labels.unsqueeze(1)
            # hook = output.register_hook(lambda grad: labels)
            gradspred, = torch.autograd.grad(output, input_tensor,
                                             grad_outputs=output.data.new(output.shape).fill_(1),
                                             create_graph=True)
            gradspred = gradspred.detach()

            atomic_grad_losses = []
            total_losses = []
            # rmsd losses for the entire batch
            rmsd_losses = (labels - output) ** 2
            rmsd_losses = rmsd_losses.detach().cpu()
            for b in range(batch_size):
                atomic_grads = torch.zeros(disp_vecs[b].shape, dtype=torch.float32, device='cuda')
                if not torch.allclose(disp_vecs[b], atomic_grads):
                    gmaker.backward(batch[b].coord_sets[-1].center(), batch[b].coord_sets[-1],
                                    gradspred[b, :14], atomic_grads)
                pred_grads = F.normalize(atomic_grads, p=1)
                true_grads = F.normalize(disp_vecs[b], p=1)
                # atomic grad loss per example
                atomic_grads_loss = torch.mean(criteria(true_grads, pred_grads), dim=0).detach().cpu()
                # total loss for the example
                total_losses.append(rmsd_losses[b] + 10 * atomic_grads_loss)
                # atomic loss for the example
                atomic_grad_losses.append(atomic_grads_loss)

            # mean losses from all batches
            total_loss_mean.append(torch.mean(torch.stack(total_losses)).cpu())
            rmsd_loss_mean.append(torch.mean(rmsd_losses).cpu())
            atomic_grads_loss_mean.append(torch.mean(torch.stack(atomic_grad_losses)).cpu())

        # mean loss for the testing session
        total_test_loss_mean = torch.mean(torch.stack(total_loss_mean)).cpu()
        rmsd_test_loss_mean = torch.mean(torch.stack(rmsd_loss_mean)).cpu()
        atomic_test_grads_loss_mean = torch.mean(torch.stack(atomic_grads_loss_mean)).cpu()
        return total_test_loss_mean, rmsd_test_loss_mean, atomic_test_grads_loss_mean

    checkpoint = None
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
    initialize_model(model, args)
    wandb.watch(model)

    iterations = args.iterations
    test_interval = args.test_interval
    batch_size = args.batch_size
    percent_reduced = args.percent_reduced
    outprefix = args.outprefix
    prev_total_loss_snap = ''
    prev_rmsd_loss_snap = ''
    prev_grad_loss_snap = ''
    prev_snap = ''
    initial = 0
    if args.checkpoint:
        initial = checkpoint['Iteration']
    last_test = 0

    if 'SGD' in args.solver:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
    elif 'Nesterov' in args.solver:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay,
                                    nesterov=True)
    elif 'Adam' in args.solver:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr,
                                     weight_decay=args.weight_decay)
    else:
        print("No valid solver argument passed (SGD, Adam, Nesterov)")
        sys.exit(1)
    if args.checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                           factor=args.step_reduce,
                                                           patience=args.step_when,
                                                           verbose=True)
    if args.checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    Bests = {}
    Bests['train_iteration'] = 0
    Bests['test_grad_loss'] = torch.from_numpy(np.asarray(np.inf))
    Bests['test_rmsd_loss'] = torch.from_numpy(np.asarray(np.inf))
    Bests['test_total_loss'] = torch.from_numpy(np.asarray(np.inf))
    if args.checkpoint:
        Bests = checkpoint['Bests']

    dims = gmaker.grid_dimensions(eptrain.num_types())
    tensor_shape = (batch_size,) + dims
    model.cuda()
    input_tensor = torch.zeros(tensor_shape, dtype=torch.float32,
                               device='cuda', requires_grad=True).contiguous()
    float_labels = torch.zeros(batch_size, dtype=torch.float32, device='cuda').contiguous()
    gradspred_grad = torch.zeros(tensor_shape, dtype=torch.float32,
                                 device='cuda', requires_grad=False).contiguous()
    criteria = torch.nn.MSELoss(reduction='none')
    cos = nn.CosineSimilarity(dim=1, eps=1e-10)

    for i in range(initial, iterations):
        batch = eptrain.next_batch(batch_size)
        transformers = []
        disp_vecs = []
        batch.extract_label(0, float_labels)
        labels = float_labels.to('cuda')
        for b in range(batch_size):
            # faulty examples (don't know why)
            try:
                disp_vecs.append(batch[b].coord_sets[0].coords.tonumpy()
                                 - batch[b].coord_sets[2].coords.tonumpy())
                disp_vecs[-1] = torch.from_numpy(disp_vecs[-1]).cuda()
                del batch[b].coord_sets[0]
            except Exception:
                print(batch[b].coord_sets[0].coords.tonumpy().shape,
                      batch[b].coord_sets[2].coords.tonumpy().shape)
                disp_vecs.append(torch.zeros(batch[b].coord_sets[2].coords.tonumpy().shape,
                                             dtype=torch.float32).cuda())
                del batch[b].coord_sets[0]
            transformer = molgrid.Transform(batch[b].coord_sets[-1].center(), 6, True)
            # gridding with a transform does not change the underlying coordinates
            gmaker.forward(batch[b], transformer, input_tensor[b])
            transformers.append(transformer)

        labels = labels.reshape(batch_size, 1)
        labels = labels.contiguous()
        optimizer.zero_grad()
        output = model(input_tensor)

        # use true RMSD values for grid gradients
        # hook = output.register_hook(lambda grad: labels)
        gradspred, = torch.autograd.grad(output, input_tensor,
                                         grad_outputs=output.data.new(output.shape).fill_(1),
                                         create_graph=True)

        # losses for the batch
        atomic_grad_losses = []
        # rmsd_losses = (labels - output) ** 2
        labels = labels.contiguous()
        rmsd_losses = criteria(output, labels)
        total_losses = []
        for b in range(batch_size):
            atomic_grads = torch.zeros(disp_vecs[b].shape, dtype=torch.float32,
                                       device='cuda', requires_grad=True)
            atomic_grads1 = torch.zeros(disp_vecs[b].shape, dtype=torch.float32, device='cuda')
            # apply transform to underlying coords
            transformers[b].forward(batch[b], batch[b])
            if not torch.allclose(disp_vecs[b], atomic_grads):
                gmaker.backward(transformers[b].get_rotation_center(), batch[b].coord_sets[-1],
                                gradspred[b, :14], atomic_grads)
            pred_grads = F.normalize(atomic_grads, p=1)
            true_grads = F.normalize(disp_vecs[b], p=1)
            cost = criteria(pred_grads, true_grads)
            # cost = 1 - torch.cosine_similarity(atomic_grads, disp_vecs[b], dim=1)
            cost.mean().backward()
            batch[b].coord_sets[1].make_vector_types()
            type_grad = torch.zeros(batch[b].coord_sets[1].type_vector.tonumpy().shape,
                                    dtype=torch.float32, device='cuda')
            gmaker.backward_gradients(transformers[b].get_rotation_center(), batch[b].coord_sets[1],
                                      gradspred[b, :14], atomic_grads.grad, type_grad,
                                      gradspred_grad[b, :14], atomic_grads1, type_grad)
            # atomic loss for the example
            atomic_grads_loss = cost.mean().cpu()
            # atomic losses for the batch
            atomic_grad_losses.append(atomic_grads_loss)
            # total losses for the batch
            total_losses.append(rmsd_losses[b] + 10 * atomic_grads_loss)

        # total loss for the batch
        total_loss_mean = torch.mean(torch.stack(total_losses).cpu())
        # rmsd loss for the batch
        rmsd_loss_mean = torch.mean(rmsd_losses)
        # atomic loss for the batch
        atomic_grads_loss_mean = torch.mean(torch.stack(atomic_grad_losses).cpu())

        '''gradspred*=gradspred_grad
        gradspred=gradspred.contiguous()
        loss = rmsd_losses.mean() + gradspred.contiguous().sum()
        loss = loss.contiguous()
        input_tensor=input_tensor.contiguous()
        loss.backward()'''

        gradspred_grad = gradspred_grad * args.weight
        gradspred.backward(gradspred_grad, retain_graph=True)
        # hook.remove()
        rmsd_losses.mean().backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
        optimizer.step()

        wandb.log({'total_train_loss': torch.sqrt(total_loss_mean), 'iteration': i + 1,
                   'rmsd_train_loss': torch.sqrt(rmsd_loss_mean),
                   'atomic_grad_train_loss': atomic_grads_loss_mean})

        if i % test_interval == 0 and i != 0:
            total_loss_mean, rmsd_loss_mean, atomic_grads_loss_mean = test_model(
                model, eptest, gmaker, percent_reduced, batch_size)
            scheduler.step(total_loss_mean)
            print('done')

            if total_loss_mean < Bests['test_total_loss']:
                Bests['test_total_loss'] = total_loss_mean
                wandb.run.summary["total_test_test_loss"] = torch.sqrt(Bests['test_total_loss'])
                Bests['train_iteration'] = i
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'Bests': Bests,
                            'Iteration': i + 1}, outprefix + '_best_total_' + str(i + 1) + '.pth.tar')
                if prev_total_loss_snap:
                    os.remove(prev_total_loss_snap)
                prev_total_loss_snap = outprefix + '_best_total_' + str(i + 1) + '.pth.tar'

            # stop after this test if there has been no improvement for step_when
            # tests and the learning rate has already been reduced to its floor
            if i - Bests['train_iteration'] >= args.step_when and \
                    optimizer.param_groups[0]['lr'] <= ((args.step_reduce) ** args.step_end_cnt) * args.base_lr:
                last_test = 1

            if rmsd_loss_mean < Bests['test_rmsd_loss']:
                Bests['test_rmsd_loss'] = rmsd_loss_mean
                wandb.run.summary["rmsd_test_test_loss"] = torch.sqrt(Bests['test_rmsd_loss'])
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'Bests': Bests,
                            'Iteration': i + 1}, outprefix + '_best_rmsd_' + str(i + 1) + '.pth.tar')
                if prev_rmsd_loss_snap:
                    os.remove(prev_rmsd_loss_snap)
                prev_rmsd_loss_snap = outprefix + '_best_rmsd_' + str(i + 1) + '.pth.tar'

            if atomic_grads_loss_mean < Bests['test_grad_loss']:
                Bests['test_grad_loss'] = atomic_grads_loss_mean
                wandb.run.summary["atomic_grad_test_test_loss"] = Bests['test_grad_loss']
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'Bests': Bests,
                            'Iteration': i + 1}, outprefix + '_best_atom_' + str(i + 1) + '.pth.tar')
                if prev_grad_loss_snap:
                    os.remove(prev_grad_loss_snap)
                prev_grad_loss_snap = outprefix + '_best_atom_' + str(i + 1) + '.pth.tar'

            print("Iteration {}, total_test_loss: {:.3f}, rmsd_test_loss: {:.3f}, "
                  "grad_test_loss: {:.3f}, Best_total_loss: {:.3f}, Best_rmsd_loss: {:.3f}, "
                  "Best_grad_loss: {:.3f}, learning_rate: {:.7f}".format(
                      i + 1, torch.sqrt(total_loss_mean), torch.sqrt(rmsd_loss_mean),
                      torch.sqrt(atomic_grads_loss_mean), torch.sqrt(Bests['test_total_loss']),
                      torch.sqrt(Bests['test_rmsd_loss']), torch.sqrt(Bests['test_grad_loss']),
                      optimizer.param_groups[0]['lr']))
            wandb.log({'total_test_test_loss': torch.sqrt(total_loss_mean), 'iteration': i + 1,
                       'rmsd_test_test_loss': torch.sqrt(rmsd_loss_mean),
                       'atomic_grad_test_test_loss': atomic_grads_loss_mean,
                       'learning rate': optimizer.param_groups[0]['lr']})
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'Bests': Bests,
                        'Iteration': i + 1}, outprefix + '_' + str(i + 1) + '.pth.tar')
            if prev_snap:
                os.remove(prev_snap)
            prev_snap = outprefix + '_' + str(i + 1) + '.pth.tar'
            if last_test:
                return Bests
def forward(self, interpolate=False, spherical=False):
    assert len(self) > 0, 'data is empty'

    # get next batch of structures
    examples = self.ex_provider.next_batch(self.batch_size)
    labels = torch.zeros(self.batch_size, device=self.device)
    examples.extract_label(0, labels)

    # create lists for examples, structs and transforms
    batch_list = lambda: [None] * self.batch_size
    input_examples = batch_list()
    input_rec_structs = batch_list()
    input_lig_structs = batch_list()
    input_transforms = batch_list()
    cond_examples = batch_list()
    cond_rec_structs = batch_list()
    cond_lig_structs = batch_list()
    cond_transforms = batch_list()

    # create output tensors for atomic density grids
    input_grids = torch.zeros(
        self.batch_size,
        self.n_channels,
        *self.grid_maker.spatial_grid_dimensions(),
        dtype=torch.float32,
        device=self.device,
    )
    cond_grids = torch.zeros(
        self.batch_size,
        self.n_channels,
        *self.grid_maker.spatial_grid_dimensions(),
        dtype=torch.float32,
        device=self.device,
    )

    # split examples, create structs and transforms
    for i, ex in enumerate(examples):

        if self.diff_cond_structs:
            # different input and conditional molecules
            input_rec_coord_set, input_lig_coord_set, \
                cond_rec_coord_set, cond_lig_coord_set = ex.coord_sets

            # split example into inputs and conditions
            input_ex = molgrid.Example()
            input_ex.coord_sets.append(input_rec_coord_set)
            input_ex.coord_sets.append(input_lig_coord_set)
            cond_ex = molgrid.Example()
            cond_ex.coord_sets.append(cond_rec_coord_set)
            cond_ex.coord_sets.append(cond_lig_coord_set)
        else:
            # same conditional molecules as input
            input_rec_coord_set, input_lig_coord_set = ex.coord_sets
            cond_rec_coord_set, cond_lig_coord_set = ex.coord_sets
            input_ex = cond_ex = ex

        # store split examples for gridding
        input_examples[i] = input_ex
        cond_examples[i] = cond_ex

        # convert coord sets to atom structs
        input_rec_structs[i] = atom_structs.AtomStruct.from_coord_set(
            input_rec_coord_set,
            typer=self.rec_typer,
            data_root=self.root_dir,
            device=self.device)
        input_lig_structs[i] = atom_structs.AtomStruct.from_coord_set(
            input_lig_coord_set,
            typer=self.lig_typer,
            data_root=self.root_dir,
            device=self.device)
        if self.diff_cond_structs:
            cond_rec_structs[i] = atom_structs.AtomStruct.from_coord_set(
                cond_rec_coord_set,
                typer=self.rec_typer,
                data_root=self.root_dir,
                device=self.device)
            cond_lig_structs[i] = atom_structs.AtomStruct.from_coord_set(
                cond_lig_coord_set,
                typer=self.lig_typer,
                data_root=self.root_dir,
                device=self.device)
        else:
            # same structs as input
            cond_rec_structs[i] = input_rec_structs[i]
            cond_lig_structs[i] = input_lig_structs[i]

        # create input transform
        input_transforms[i] = molgrid.Transform(
            center=input_lig_coord_set.center(),
            random_translate=self.random_translation,
            random_rotation=self.random_rotation,
        )
        if self.diff_cond_transform:
            # create conditional transform
            cond_transforms[i] = molgrid.Transform(
                center=cond_lig_coord_set.center(),
                random_translate=self.random_translation,
                random_rotation=self.random_rotation,
            )
        else:
            # same transform as input
            cond_transforms[i] = input_transforms[i]

    if interpolate:
        # interpolate the conditional transforms, i.e. the
        # location and orientation of the conditional grid
        if not self.cond_interp.is_initialized:
            self.cond_interp.initialize(cond_examples[0])
        cond_transforms = self.cond_interp(
            transforms=cond_transforms,
            spherical=spherical,
        )

    # create density grids
    for i in range(self.batch_size):

        # create input density grid
        self.grid_maker.forward(input_examples[i], input_transforms[i], input_grids[i])

        if (self.diff_cond_transform or self.diff_cond_structs or interpolate):
            # create conditional density grid
            self.grid_maker.forward(cond_examples[i], cond_transforms[i], cond_grids[i])
        else:
            # same density grid as input
            cond_grids[i] = input_grids[i]

    input_structs = (input_rec_structs, input_lig_structs)
    cond_structs = (cond_rec_structs, cond_lig_structs)
    transforms = (input_transforms, cond_transforms)
    return (input_grids, cond_grids, input_structs,
            cond_structs, transforms, labels)
def test_a_grid():
    fname = datadir + "/small.types"
    e = molgrid.ExampleProvider(data_root=datadir + "/structs")
    e.populate(fname)
    ex = e.next()
    c = ex.coord_sets[1]

    assert np.min(c.type_index.tonumpy()) >= 0

    gmaker = molgrid.GridMaker()
    dims = gmaker.grid_dimensions(c.max_type)  # this should be grid_dims or get_grid_dims
    center = tuple(c.center())

    mgridout = molgrid.MGrid4f(*dims)
    mgridgpu = molgrid.MGrid4f(*dims)
    npout = np.zeros(dims, dtype=np.float32)
    torchout = torch.zeros(dims, dtype=torch.float32)
    cudaout = torch.zeros(dims, dtype=torch.float32, device='cuda')

    gmaker.forward(center, c, mgridout.cpu())
    gmaker.forward(center, c, mgridgpu.gpu())
    gmaker.forward(center, c, npout)
    gmaker.forward(center, c, torchout)
    gmaker.forward(center, c, cudaout)

    newt = gmaker.make_tensor(center, c)
    newa = gmaker.make_ndarray(center, c)

    assert 1.438691 == approx(mgridout.tonumpy().max())
    assert 1.438691 == approx(mgridgpu.tonumpy().max())
    assert 1.438691 == approx(npout.max())
    assert 1.438691 == approx(torchout.numpy().max())
    assert 1.438691 == approx(cudaout.cpu().numpy().max())
    assert 1.438691 == approx(newt.cpu().numpy().max())
    assert 1.438691 == approx(newa.max())

    # should overwrite by default, yes?
    gmaker.forward(center, c, mgridout.cpu())
    gmaker.forward(center, c, mgridgpu.gpu())
    assert 1.438691 == approx(mgridout.tonumpy().max())
    assert 1.438691 == approx(mgridgpu.tonumpy().max())

    dims = gmaker.grid_dimensions(e.num_types())
    mgridout = molgrid.MGrid4f(*dims)
    mgridgpu = molgrid.MGrid4f(*dims)

    # pass transform
    gmaker.forward(ex, molgrid.Transform(center, 0, False), mgridout.cpu())
    gmaker.forward(ex, molgrid.Transform(center, 0, False), mgridgpu.gpu())
    assert 2.094017 == approx(mgridout.tonumpy().max())
    assert 2.094017 == approx(mgridgpu.tonumpy().max())

    gmaker.forward(ex, mgridout.cpu())
    gmaker.forward(ex, mgridgpu.gpu())
    gmaker.forward(ex, mgridout.cpu())
    gmaker.forward(ex, mgridgpu.gpu())
    assert 2.094017 == approx(mgridout.tonumpy().max())
    assert 2.094017 == approx(mgridgpu.tonumpy().max())
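# A sketch of why the default GridMaker produces the shapes used above,
# assuming libmolgrid's documented defaults of resolution=0.5 and
# dimension=23.5: that gives round(23.5 / 0.5) + 1 = 48 points per side,
# so grid_dimensions(ntypes) is (ntypes, 48, 48, 48).
import molgrid

gmaker = molgrid.GridMaker()
dims = gmaker.grid_dimensions(14)
assert tuple(dims) == (14, 48, 48, 48)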
def train(train_loader, model, criterion, optimizer, gmaker, tensorshape, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    # top1 = AverageMeter('Acc@1', ':6.2f')
    # top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    total_loss = 0
    end = time.time()
    for i, (lengths, center, coords, types, radii, _) in enumerate(train_loader):
        types = types.cuda(args.gpu, non_blocking=True)
        radii = radii.squeeze().cuda(args.gpu, non_blocking=True)
        coords = coords.cuda(args.gpu, non_blocking=True)
        coords_q = torch.empty(*coords.shape, device=coords.device, dtype=coords.dtype)

        batch_size = coords.shape[0]
        if batch_size != types.shape[0] or batch_size != radii.shape[0]:
            raise RuntimeError("Inconsistent batch sizes in dataset outputs")

        output1 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
        output2 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
        for idx in range(batch_size):
            t = molgrid.Transform(
                molgrid.float3(*(center[idx].numpy().tolist())),
                random_translate=2, random_rotation=True)
            t.forward(coords[idx][:lengths[idx]], coords_q[idx][:lengths[idx]])
            gmaker.forward(t.get_rotation_center(), coords_q[idx][:lengths[idx]],
                           types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                           molgrid.tensor_as_grid(output1[idx]))
            t.forward(coords[idx][:lengths[idx]], coords[idx][:lengths[idx]])
            gmaker.forward(t.get_rotation_center(), coords[idx][:lengths[idx]],
                           types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                           molgrid.tensor_as_grid(output2[idx]))

        del lengths, center, coords, types, radii
        torch.cuda.empty_cache()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output, target = model(im_q=output1, im_k=output2)
        loss = criterion(output, target)
        total_loss += float(loss.item())

        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        # acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), output1.size(0))
        # top1.update(acc1[0], images[0].size(0))
        # top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    return total_loss / len(train_loader.dataset)
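# AverageMeter and ProgressMeter are referenced throughout these training
# loops but never defined in the snippets; they follow the convention of the
# PyTorch ImageNet example. A minimal sketch under that assumption:
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.val = self.avg = self.sum = self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    def __str__(self):
        return ('{name} {val' + self.fmt + '} ({avg' + self.fmt + '})').format(**self.__dict__)

class ProgressMeter:
    """Prints the batch index alongside each meter's formatted summary."""
    def __init__(self, num_batches, meters, prefix=""):
        self.num_batches, self.meters, self.prefix = num_batches, meters, prefix
    def display(self, batch):
        entries = ['{}[{}/{}]'.format(self.prefix, batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))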
def train(train_loader, model, criterion, optimizer, gmaker, tensorshape, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    r = AverageMeter('Pearson R', ':6.2f')
    rmse = AverageMeter('RMSE', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    targets = []
    predictions = []

    """
    Switch to eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), which are part of the model parameters too.
    """
    model.eval()

    total_loss = 0
    end = time.time()
    for i, (lengths, center, coords, types, radii, labels) in enumerate(train_loader):
        types = types.cuda(args.gpu, non_blocking=True)
        radii = radii.squeeze().cuda(args.gpu, non_blocking=True)
        coords = coords.cuda(args.gpu, non_blocking=True)
        coords_q = torch.empty(*coords.shape, device=coords.device, dtype=coords.dtype)

        batch_size = coords.shape[0]
        if batch_size != types.shape[0] or batch_size != radii.shape[0]:
            raise RuntimeError("Inconsistent batch sizes in dataset outputs")

        output1 = torch.empty(batch_size, *tensorshape, dtype=coords.dtype, device=coords.device)
        for idx in range(batch_size):
            t = molgrid.Transform(
                molgrid.float3(*(center[idx].numpy().tolist())),
                random_translate=2, random_rotation=True)
            t.forward(coords[idx][:lengths[idx]], coords_q[idx][:lengths[idx]])
            gmaker.forward(t.get_rotation_center(), coords_q[idx][:lengths[idx]],
                           types[idx][:lengths[idx]], radii[idx][:lengths[idx]],
                           molgrid.tensor_as_grid(output1[idx]))

        del lengths, center, coords, types, radii
        torch.cuda.empty_cache()

        target = labels.cuda(args.gpu, non_blocking=True)

        # compute output
        prediction = model(output1)
        loss = criterion(prediction, target)

        # measure accuracy and record loss
        r_val, rmse_val = accuracy(prediction, target)
        losses.update(loss.item(), output1.size(0))
        r.update(r_val, output1.size(0))
        rmse.update(rmse_val, output1.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predictions += prediction.detach().flatten().tolist()
        targets += target.detach().flatten().tolist()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    r_avg, rmse_avg = accuracy(predictions, targets)
    return r_avg, rmse_avg