def parse_score_fn(model, parses):
    """Score parses by the negated loss ``model`` assigns to them.

    ``get_stk_from_bspline`` is applied to ``parses`` through ``nested_map``
    to build the drawings. When CUDA is available, both the drawings and the
    parses are moved with ``.cuda()`` (drawings first, then parses). The
    model's ``losses_fn`` is then evaluated with ``filter_small=False`` and
    ``denormalize=True``.

    Args:
        model: Object exposing ``losses_fn(parses, drawings, filter_small=...,
            denormalize=...)``.
        parses: Parse structure accepted by ``nested_map``; presumably a
            (nested) collection of spline tensors -- TODO confirm.

    Returns:
        The losses moved to CPU and negated, so a lower loss gives a
        higher score.
    """
    strokes = nested_map(get_stk_from_bspline, parses)

    if torch.cuda.is_available():
        def _to_gpu(item):
            return item.cuda()

        # Tuple assignment evaluates left to right: drawings move first,
        # then parses.
        strokes, parses = nested_map(_to_gpu, strokes), nested_map(_to_gpu, parses)

    loss = model.losses_fn(parses, strokes, filter_small=False, denormalize=True)
    return -loss.cpu()
def token_losses_fn(self, parses):
    """Return the negated token-model log-probability for ``parses``.

    Three values are collected from every parse: ``shape_noise`` (as a
    list), ``loc_noise`` (as a list) and ``affine``. When ``self.gpu`` is
    truthy, each collection is passed through ``nested_map(to_cuda, ...)``.
    The results are then scored jointly by
    ``self.token_model.log_prob_multi``.

    Args:
        parses: Iterable of parse objects with ``shape_noise``,
            ``loc_noise`` and ``affine`` attributes. It is iterated three
            times, so it should be a re-iterable sequence.

    Returns:
        ``-self.token_model.log_prob_multi(...)``, i.e. the loss.
    """
    # One pass per attribute, matching the original access order.
    shape_eps = [list(p.shape_noise) for p in parses]
    loc_eps = [list(p.loc_noise) for p in parses]
    affines = [p.affine for p in parses]

    if self.gpu:
        # to_cuda presumably moves each leaf to the GPU -- TODO confirm.
        shape_eps, loc_eps, affines = (
            nested_map(to_cuda, shape_eps),
            nested_map(to_cuda, loc_eps),
            nested_map(to_cuda, affines),
        )

    log_prob = self.token_model.log_prob_multi(shape_eps, loc_eps, affines)
    return -log_prob
def type_losses_fn(self, parses, drawings):
    """Evaluate the type model's losses on the splines of ``parses``.

    Each parse's ``x`` attribute is converted to a list. When ``self.gpu``
    is truthy, the drawings and then the splines are passed through
    ``nested_map(to_cuda, ...)``. The drawings are forwarded to
    ``self.type_model.losses_fn`` only when ``self.drawings_to_type`` is
    truthy. ``denormalize`` is always taken from ``self.denormalize``.

    Args:
        parses: Iterable of parse objects exposing an ``x`` attribute.
        drawings: Drawing data; it is only used when
            ``self.drawings_to_type`` is set.

    Returns:
        Whatever ``self.type_model.losses_fn`` returns.
    """
    splines = [list(p.x) for p in parses]

    if self.gpu:
        # Same move order as before: drawings first, then splines.
        drawings = nested_map(to_cuda, drawings)
        splines = nested_map(to_cuda, splines)

    positional = (splines, drawings) if self.drawings_to_type else (splines,)
    return self.type_model.losses_fn(*positional, denormalize=self.denormalize)