def _train_epoch(model, optimizer, iterator, data, shuffle=True, lr_scheduler=None):
    model.train()

    total_loss = 0.0
    generator = iterator(data, shuffle=shuffle)
    num_batches = iterator.get_num_batches(data)
    batch_counter = 0
    summary_interval = max(min(50, int(num_batches / 5)), 1)

    for batch in generator:
        batch_counter += 1
        optimizer.zero_grad()
        loss = _batch_loss(model, batch)
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")

        loss.backward()
        total_loss += loss.item()
        rescale_gradients(model)
        optimizer.step()

        metrics = model.get_metrics(reset=False)
        metrics["loss"] = float(total_loss / batch_counter) if batch_counter > 0 else 0.0

        if batch_counter % summary_interval == 0 or batch_counter == num_batches:
            print_out("%d out of %d batches, loss: %.3f" % (batch_counter, num_batches, metrics["loss"]))

    metrics = model.get_metrics(reset=True)
    metrics["loss"] = float(total_loss / batch_counter) if batch_counter > 0 else 0.0
    return metrics
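
A minimal, self-contained sketch of the same NaN-loss guard on a hypothetical toy model (the helpers used above, such as _batch_loss, rescale_gradients and print_out, are assumed to come from the surrounding training library):

import torch
import torch.nn as nn

# Hypothetical toy setup illustrating the guard above.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
if torch.isnan(loss):
    raise ValueError("nan loss encountered")
loss.backward()
optimizer.step()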
Example #2
 def replay_sample(self):
     '''Samples a batch from memory'''
     batches = [body.replay_memory.sample() for body in self.agent.nanflat_body_a]
     batch = util.concat_batches(batches)
     batch = util.to_torch_batch(batch, self.net.gpu)
     assert not torch.isnan(batch['states']).any()
     return batch
Example #3
    def pytorch_net_to_buffer(pytorch_net, input_dim, model_on_gpu, float_input=True):
        """Traces a pytorch net and outputs a python buffer object
        holding net."""

        training = pytorch_net.training
        pytorch_net.train(False)

        for name, p in pytorch_net.named_parameters():
            inf_count = torch.isinf(p).sum().item()
            nan_count = torch.isnan(p).sum().item()
            assert inf_count + nan_count == 0, "{} has {} inf and {} nan".format(
                name, inf_count, nan_count
            )

        if float_input:
            dtype = torch.cuda.FloatTensor if model_on_gpu else torch.FloatTensor
            dummy_input = torch.randn(1, input_dim).type(dtype)
        else:
            dtype = torch.cuda.LongTensor if model_on_gpu else torch.LongTensor
            dummy_input = torch.randint(low=0, high=1, size=(1, input_dim)).type(dtype)

        write_buffer = BytesIO()
        try:
            torch.onnx.export(pytorch_net, dummy_input, write_buffer)
        finally:
            pytorch_net.train(training)
        return write_buffer
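
A hedged usage sketch of the same pattern on a toy network, assuming an ONNX-capable PyTorch install; the finite-parameter check mirrors the assert above:

import torch
import torch.nn as nn
from io import BytesIO

# Toy net: verify every parameter is finite, then trace it into an in-memory buffer.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, p in net.named_parameters():
    assert torch.isfinite(p).all(), "{} has inf/nan values".format(name)

buffer = BytesIO()
torch.onnx.export(net.eval(), torch.randn(1, 4), buffer)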
Example #4
def torch_isnan(x):
    """
    A convenience function to check whether a Tensor is entirely NaN; also works with plain numbers
    """
    if isinstance(x, numbers.Number):
        return x != x
    return torch.isnan(x).all()
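
A quick illustration of the .all() semantics above (hypothetical values): a plain float NaN compares unequal to itself, while a tensor only counts as NaN when every element is NaN:

import numbers  # required by torch_isnan above
import torch

torch_isnan(float('nan'))                       # True  (nan != nan)
torch_isnan(3.0)                                # False
torch_isnan(torch.tensor([float('nan'), 1.0]))  # tensor(False): not all elements are NaN
torch_isnan(torch.full((2,), float('nan')))     # tensor(True)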
Example #5
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
    """ Utility function for optimize_on_cpu and 16-bits training.
        Copy the gradients of the GPU parameters to the CPU/RAM copy of the model
    """
    is_nan = False
    for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
        if name_opti != name_model:
            logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
            raise ValueError
        if test_nan and torch.isnan(param_model.grad).sum() > 0:
            is_nan = True
        if param_opti.grad is None:
            param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
        param_opti.grad.data.copy_(param_model.grad.data)
    return is_nan
Example #6
 def training_step(self, x=None, y=None, loss=None, retain_graph=False):
     '''Takes a single training step: one forward and one backward pass'''
     self.train()
     self.zero_grad()
     self.optim.zero_grad()
     if loss is None:
         out = self(x)
         loss = self.loss_fn(out, y)
     assert not torch.isnan(loss).any()
     if net_util.to_assert_trained():
         assert_trained = net_util.gen_assert_trained(self.conv_model)
     loss.backward(retain_graph=retain_graph)
     if self.clip_grad:
         logger.debug('Clipping gradient')
         torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
     self.optim.step()
     if net_util.to_assert_trained():
         assert_trained(self.conv_model)
     return loss
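
A self-contained sketch of the same step on a plain classifier, using the in-place torch.nn.utils.clip_grad_norm_ for the clipping and the same NaN assert on the loss (toy model and data are hypothetical):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))

optim.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
assert not torch.isnan(loss).any()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optim.step()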
Example #7
    def attack_FGSM(self,
                    label,
                    out_dir=None,
                    save_title=None,
                    steps=5,
                    vertex_eps=0.001,
                    pose_eps=0.05,
                    lighting_eps=4000,
                    vertex_attack=True,
                    pose_attack=True,
                    lighting_attack=False):
        if out_dir is not None and save_title is None:
            raise Exception("Must provide image title if out dir is provided")
        elif save_title is not None and out_dir is None:
            raise Exception("Must provide directory if image is to be saved")

        filename = save_title

        # classify
        img = self.render_image(out_dir=out_dir, filename=filename)
        # only there to zero out gradients.
        optimizer = torch.optim.Adam(
            [self.translation, self.euler_angles, self.light.intensity], lr=0)

        for i in range(steps):
            optimizer.zero_grad()
            pred, net_out = self.classify(img)
            if pred.item() != label and i != 0:
                print("misclassification at step ", i)
                final_image = np.clip(
                    img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
                return pred, final_image
            # get gradients
            self._get_gradients(img.cpu(), net_out, label)

            delta = 1e-6
            inf_count = 0
            nan_count = 0

            # attack each shape's vertices
            if vertex_attack:
                for shape in self.shapes:
                    # check NaN before the general finiteness test, since NaN is also non-finite
                    if torch.isnan(shape.vertices.grad).any():
                        nan_count += 1
                    elif not torch.isfinite(shape.vertices.grad).all():
                        inf_count += 1
                    else:
                        # subtract because we are trying to decrease the classification score of the label
                        shape.vertices -= torch.sign(
                            shape.vertices.grad /
                            (torch.norm(shape.vertices.grad) +
                             delta)) * vertex_eps

            if pose_attack:
                self.euler_angles.data -= torch.sign(
                    self.euler_angles.grad /
                    (torch.norm(self.euler_angles.grad) + delta)) * pose_eps

            if lighting_attack:
                light_sub = torch.sign(self.light.intensity.grad /
                                       (torch.norm(self.light.intensity.grad) +
                                        delta)) * lighting_eps
                light_sub = torch.min(self.light.intensity.data, light_sub)
                self.light.intensity.data -= light_sub

            img = self.render_image(out_dir=out_dir, filename=filename)

        final_pred, net_out = self.classify(img)
        final_image = np.clip(img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
        return final_pred, final_image
Example #8
def _warn_if_nan(tensor, name):
    if torch.isnan(tensor).any():
        warnings.warn('Encountered nan elements in {}'.format(name))
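
Hypothetical usage of the helper above; it only emits a warning and leaves the tensor untouched:

import warnings
import torch

grad = torch.tensor([1.0, float('nan'), 3.0])
_warn_if_nan(grad, 'gradient')   # UserWarning: Encountered nan elements in gradient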
Example #9
    def forward(self,
                im_data,
                im_info,
                gt_objects=None,
                gt_relationships=None,
                rpn_anchor_targets_obj=None):
        # timing the process
        base_timer = Timer()
        mps_timer = Timer()
        infer_timer = Timer()
        assert im_data.size(0) == 1, "Only support Batch Size equals 1"
        base_timer.tic()
        # Currently, the RPN supports batching, but MSDN does not
        features, object_rois, rpn_losses = self.rpn(
            im_data, im_info, rpn_data=rpn_anchor_targets_obj)
        # pdb.set_trace()
        if self.training:
            roi_data_object, roi_data_predicate, roi_data_region, mat_object, mat_phrase, mat_region = \
                self.proposal_target_layer(object_rois, gt_objects[0], gt_relationships[0], self.n_classes_obj)
            object_rois = roi_data_object[1]
            region_rois = roi_data_region[1]
        else:
            object_rois, region_rois, mat_object, mat_phrase, mat_region = self.graph_construction(
                object_rois, )
        # roi pool
        pooled_object_features = self.roi_pool_object(features,
                                                      object_rois).view(
                                                          len(object_rois), -1)
        pooled_object_features = self.fc_obj(pooled_object_features)
        # print 'fc7_object.std', pooled_object_features.data.std()

        pooled_region_features = self.roi_pool_region(features, region_rois)
        pooled_region_features = self.fc_region(pooled_region_features)

        bbox_object = self.bbox_obj(F.relu(pooled_object_features))
        base_timer.toc()

        mps_timer.tic()

        for i, mps in enumerate(self.mps_list):
            pooled_object_features, pooled_region_features = \
                mps(pooled_object_features, pooled_region_features, mat_object, mat_region, object_rois, region_rois)

        mps_timer.toc()

        infer_timer.tic()
        cls_score_object = self.score_obj(F.relu(pooled_object_features))
        if self.con_net.args.CON_use == '1':
            pooled_object_features = self.con_net(pooled_object_features,
                                                  cls_score_object)
        pooled_phrase_features = self.phrase_inference(pooled_object_features,
                                                       pooled_region_features,
                                                       mat_phrase)
        infer_timer.toc()

        # cls_score_object = self.score_obj(F.relu(pooled_object_features))
        cls_prob_object = F.softmax(cls_score_object, dim=1)
        cls_score_predicate = self.score_pred(F.relu(pooled_phrase_features))
        cls_prob_predicate = F.softmax(cls_score_predicate, dim=1)

        # reconstruction
        if self.reconstruction_net.args.RC_use == '1':
            re_img = self.reconstruction_net(pooled_object_features,
                                             object_rois)

        if TIME_IT:
            print('TIMING:')
            print('[CNN]:\t{0:.3f} s'.format(base_timer.average_time))
            print('[MPS]:\t{0:.3f} s'.format(mps_timer.average_time))
            print('[INF]:\t{0:.3f} s'.format(infer_timer.average_time))

        # pdb.set_trace()
        # object cls loss
        loss_cls_obj, (tp, tf, fg_cnt, bg_cnt) = \
                build_loss_cls(cls_score_object, roi_data_object[0],
                    loss_weight=self.object_loss_weight.to(cls_score_object.get_device()))
        # object regression loss
        loss_reg_obj = build_loss_bbox(bbox_object, roi_data_object, fg_cnt)
        # predicate cls loss
        loss_cls_rel,  (tp_pred, tf_pred, fg_cnt_pred, bg_cnt_pred)= \
                build_loss_cls(cls_score_predicate, roi_data_predicate[0],
                    loss_weight=self.predicate_loss_weight.to(cls_score_predicate.get_device()))

        # AEcoder
        reconstruction_loss = None
        if self.reconstruction_net.args.RC_use == '1':
            reconstruction_loss = build_loss_reconstruct(re_img, im_data)

        # GAN
        disc_loss = None
        if self.reconstruction_net.args.GAN_use == '1':
            discriminator_real = self.discriminator(im_data)
            discriminator_fake = self.discriminator(re_img)
            disc_loss = build_loss_GAN(
                torch.cat([discriminator_real, discriminator_fake], 0))

        losses = {
            'rpn':
            rpn_losses,
            'loss_cls_obj':
            loss_cls_obj,
            'loss_reg_obj':
            torch.zeros_like(loss_reg_obj)
            if torch.isnan(loss_reg_obj) else loss_reg_obj,
            'loss_cls_rel':
            loss_cls_rel,
            'tf':
            tf,
            'tp':
            tp,
            'fg_cnt':
            fg_cnt,
            'bg_cnt':
            bg_cnt,
            'tp_pred':
            tp_pred,
            'tf_pred':
            tf_pred,
            'fg_cnt_pred':
            fg_cnt_pred,
            'bg_cnt_pred':
            bg_cnt_pred,
        }

        if self.reconstruction_net.args.RC_use == '1':
            losses['reconstruction_loss'] = reconstruction_loss

        if self.reconstruction_net.args.GAN_use == '1':
            losses['disc_loss'] = disc_loss
        # loss for NMS
        if self.learnable_nms:
            duplicate_labels = roi_data_object[4][:, 1:2]
            duplicate_weights = roi_data_object[4][:, 0:1]
            if duplicate_weights.data.sum() == 0:
                loss_nms = loss_cls_rel * 0  # Guarantee the data type
            else:
                mask = torch.zeros_like(cls_prob_object).byte()
                for i in range(duplicate_labels.size(0)):
                    mask[i, roi_data_object[0].data[i][0]] = 1
                selected_prob = torch.masked_select(cls_prob_object, mask)
                reranked_score = self.nms(pooled_object_features,
                                          selected_prob, roi_data_object[1])
                selected_prob = selected_prob.unsqueeze(1) * reranked_score
                loss_nms = F.binary_cross_entropy(
                    selected_prob,
                    duplicate_labels,
                    weight=duplicate_weights,
                    size_average=False) / (duplicate_weights.data.sum() +
                                           1e-10)
            losses["loss_nms"] = loss_nms

        losses['loss'] = self.loss(losses)

        return losses
Example #10
    def get_clutter_model(self, compnet_type, vMF_kappa):
        idir = 'background_images/'
        vc_num = self.conv1o1.weight.shape[0]

        updated_models = torch.zeros((0, vc_num))
        boo_gpu = (self.conv1o1.weight.device.type == 'cuda')
        gpu_id = self.conv1o1.weight.device.index
        if boo_gpu:
            updated_models = updated_models.cuda(gpu_id)

        if self.compnet_type == 'vmf':
            occ_types = occ_types_vmf
        elif self.compnet_type == 'bernoulli':
            occ_types = occ_types_bern

        for j in range(len(occ_types)):
            occ_type = occ_types[j]
            with torch.no_grad():
                files = glob.glob(idir + '*' + occ_type + '.JPEG')
                clutter_feats = torch.zeros((0, vc_num))
                if boo_gpu:
                    clutter_feats = clutter_feats.cuda(gpu_id)
                for i in range(len(files)):
                    file = files[i]
                    img, _ = imgLoader(file, [[]],
                                       bool_resize_images=False,
                                       bool_square_images=False)
                    if boo_gpu:
                        img = img.cuda(gpu_id)

                    feats = self.activation_layer(
                        self.conv1o1(
                            self.backbone(
                                img.reshape(1, img.shape[0], img.shape[1],
                                            img.shape[2]))))[0].transpose(
                                                1, 2)
                    feats_reshape = torch.reshape(feats,
                                                  [vc_num, -1]).transpose(
                                                      0, 1)
                    clutter_feats = torch.cat((clutter_feats, feats_reshape))

                mean_activation = torch.reshape(
                    torch.sum(clutter_feats, dim=1),
                    (-1, 1)).repeat([1, vc_num])
                if compnet_type == 'bernoulli':
                    boo = torch.sum(mean_activation, dim=1) != 0
                    mean_vec = torch.mean(clutter_feats[boo] /
                                          mean_activation[boo],
                                          dim=0)
                    updated_models = torch.cat(
                        (updated_models, mean_vec.reshape(1, -1)))
                else:
                    if (occ_type == '_white' or occ_type == '_noise'):
                        mean_vec = torch.mean(clutter_feats / mean_activation,
                                              dim=0)
                        updated_models = torch.cat(
                            (updated_models, mean_vec.reshape(1, -1)))
                    else:
                        nc = 5
                        model = vMFMM(nc, 'k++')
                        model.fit(clutter_feats.cpu().numpy(),
                                  vMF_kappa,
                                  max_it=150,
                                  tol=1e-10)
                        mean_vec = torch.zeros(
                            nc, clutter_feats.shape[1]).cuda(gpu_id)
                        mean_act = torch.zeros(
                            nc, clutter_feats.shape[1]).cuda(gpu_id)
                        clust_cnt = torch.zeros(nc)
                        for v in range(model.p.shape[0]):
                            assign = np.argmax(model.p[v])
                            mean_vec[assign] += clutter_feats[v]
                            clust_cnt[assign] += 1

                        mean_vec_final = torch.zeros(
                            sum(clust_cnt > 0),
                            clutter_feats.shape[1]).cuda(gpu_id)
                        cnt = 0
                        for v in range(mean_vec.shape[0]):
                            if clust_cnt[v] > 0:
                                mean_vec_final[cnt] = (
                                    mean_vec[v] /
                                    clust_cnt[v].cuda(gpu_id)).t()
                                cnt += 1
                        updated_models = torch.cat(
                            (updated_models, mean_vec_final))

                        if torch.isnan(updated_models.min()):
                            print('ISNAN IN CLUTTER MODEL')

        return updated_models
Example #11
def train(train_loader, net, criterion, optimizer, curr_epoch, writer):
    '''
    Runs the training loop per epoch
    train_loader: Data loader for train
    net: the network
    criterion: loss fn
    optimizer: optimizer
    curr_epoch: current epoch 
    writer: tensorboard writer
    return: val_avg for step function if required
    '''
    net.train()

    train_main_loss = AverageMeter()
    train_edge_loss = AverageMeter()
    train_seg_loss = AverageMeter()
    train_att_loss = AverageMeter()
    train_dual_loss = AverageMeter()
    curr_iter = curr_epoch * len(train_loader)

    for i, data in enumerate(train_loader):
        if i == 0:
            print('running....')

        inputs, mask, edge, _img_name = data

        if torch.sum(torch.isnan(inputs)) > 0:
            import pdb
            pdb.set_trace()

        batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

        inputs, mask, edge = inputs.cuda(), mask.cuda(), edge.cuda()

        if i == 0:
            print('forward done')

        optimizer.zero_grad()

        main_loss = None
        loss_dict = None

        if args.joint_edgeseg_loss:
            loss_dict = net(inputs, gts=(mask, edge))

            if args.seg_weight > 0:
                log_seg_loss = loss_dict['seg_loss'].mean().clone().detach_()
                train_seg_loss.update(log_seg_loss.item(), batch_pixel_size)
                main_loss = loss_dict['seg_loss']

            if args.edge_weight > 0:
                log_edge_loss = loss_dict['edge_loss'].mean().clone().detach_()
                train_edge_loss.update(log_edge_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['edge_loss']
                else:
                    main_loss = loss_dict['edge_loss']

            if args.att_weight > 0:
                log_att_loss = loss_dict['att_loss'].mean().clone().detach_()
                train_att_loss.update(log_att_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['att_loss']
                else:
                    main_loss = loss_dict['att_loss']

            if args.dual_weight > 0:
                log_dual_loss = loss_dict['dual_loss'].mean().clone().detach_()
                train_dual_loss.update(log_dual_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['dual_loss']
                else:
                    main_loss = loss_dict['dual_loss']

        else:
            main_loss = net(inputs, gts=mask)

        main_loss = main_loss.mean()
        log_main_loss = main_loss.clone().detach_()

        train_main_loss.update(log_main_loss.item(), batch_pixel_size)

        main_loss.backward()

        optimizer.step()

        if i == 0:
            print('step 1 done')

        curr_iter += 1

        if args.local_rank == 0:
            msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [seg loss {:0.6f}], [edge loss {:0.6f}], [lr {:0.6f}]'.format(
                curr_epoch, i + 1, len(train_loader), train_main_loss.avg,
                train_seg_loss.avg, train_edge_loss.avg,
                optimizer.param_groups[-1]['lr'])

            logging.info(msg)

            # Log tensorboard metrics for each iteration of the training phase
            writer.add_scalar('training/loss', (train_main_loss.val),
                              curr_iter)
            writer.add_scalar('training/lr', optimizer.param_groups[-1]['lr'],
                              curr_iter)
            if args.joint_edgeseg_loss:

                writer.add_scalar('training/seg_loss', (train_seg_loss.val),
                                  curr_iter)
                writer.add_scalar('training/edge_loss', (train_edge_loss.val),
                                  curr_iter)
                writer.add_scalar('training/att_loss', (train_att_loss.val),
                                  curr_iter)
                writer.add_scalar('training/dual_loss', (train_dual_loss.val),
                                  curr_iter)
        if i > 5 and args.test_mode:
            return
Example #12
def assert_nan_inf(input_tensors: List[torch.Tensor]) -> None:
    for tensor in input_tensors:
        if torch.isnan(tensor).any() or torch.isinf(tensor).any():
            print(f"nan or inf found in: {tensor}")
            raise Exception("found nan or inf")
Example #13
    def torch_fit(self,
                  inp_tensor: torch.Tensor,
                  validation_batch: torch.Tensor,
                  lr: float = 0.05,
                  epochs: int = 1000,
                  validation_freq: int = 1,
                  quiet: bool = False,
                  decay: float = 0.000001) -> float:

        # nan-agnostic scaling values
        nan_loc = torch.isnan(inp_tensor)
        clean_tensor = inp_tensor[~nan_loc]
        means = clean_tensor.mean(dim=0, keepdim=True)
        stds = clean_tensor.std(dim=0, keepdim=True).clamp(min=0.00001)
        self._means = means.data.cpu().numpy()
        self._stds = stds.data.cpu().numpy()

        # scaling & imputing
        inp_tensor = (inp_tensor - means) / stds
        validation_batch = (validation_batch - means) / stds
        inp_tensor[nan_loc] = 0
        validation_batch[torch.isnan(validation_batch)] = 0

        self._dim = inp_tensor.size(1)
        flow_params = [
            f.generate_params(self._dim).to(self._device) for f in self._flows
        ]
        for flow, p in zip(self._flows, flow_params):
            flow.set_params(p)

        optim_params = sum([list(fp.parameters()) for fp in flow_params], [])
        optimizer = torch.optim.Adam(optim_params, lr=lr, weight_decay=decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, epochs)

        best_loss = np.inf
        progress_bar = tqdm(range(epochs), total=epochs, disable=quiet)
        for epoch in progress_bar:
            loss = self._loss(inp_tensor, random_weights=False)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(optim_params, 1)
            optimizer.step()
            scheduler.step(epoch)

            # validation
            if epoch % validation_freq == 0:
                loss = self._loss(validation_batch).item()
                if loss < best_loss:
                    best_loss = loss
                    best_flows = [
                        deepcopy(fp.state_dict()) for fp in flow_params
                    ]
                if tqdm_installed:
                    desc = "loss=%.04f;%.04f" % (loss, best_loss)
                    progress_bar.set_description(desc)

        for p, s in zip(flow_params, best_flows):
            p.load_state_dict(s)

        return best_loss
Example #14
def avg_precision(outputs: torch.Tensor,
                  targets: torch.Tensor) -> torch.Tensor:
    """
    Calculate the Average Precision for RecSys.
    The precision metric summarizes the fraction of relevant items
    in the whole recommendation list.

    To compute precision at k, set the threshold at rank k and
    compute the fraction of relevant items in the top k,
    ignoring documents ranked lower than k.

    The average precision at k (AP at k) summarizes the average
    precision for relevant items up to the k-th one.
    Wikipedia entry for the Average precision

    <https://en.wikipedia.org/w/index.php?title=Information_retrieval&
    oldid=793358396#Average_precision>

    If a relevant document never gets retrieved,
    we assume the precision corresponding to that
    relevant doc to be zero

    Args:
        outputs (torch.Tensor):
            Tensor with predicted score
            size: [batch_size, slate_length]
            model outputs, logits
        targets (torch.Tensor):
            Binary tensor with ground truth.
            1 means the item is relevant
            and 0 not relevant
            size: [batch_size, slate_length]
            ground truth, labels

    Returns:
        ap_score (torch.Tensor):
            The map score for each batch.
            size: [batch_size, 1]

    Examples:
        >>> avg_precision(
        >>>     outputs=torch.tensor([
        >>>         [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        >>>         [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
        >>>     ]),
        >>>     targets=torch.tensor([
        >>>         [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0],
        >>>         [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        >>>     ]),
        >>> )
        tensor([0.6222, 0.4429])
    """
    targets_sort_by_outputs = process_recsys_components(outputs, targets)
    precisions = torch.zeros_like(targets_sort_by_outputs)

    for index in range(outputs.size(1)):
        precisions[:,
                   index] = torch.sum(targets_sort_by_outputs[:, :(index + 1)],
                                      dim=1) / float(index + 1)

    only_relevant_precision = precisions * targets_sort_by_outputs
    ap_score = only_relevant_precision.sum(dim=1) / (
        (only_relevant_precision != 0).sum(dim=1))
    ap_score[torch.isnan(ap_score)] = 0
    return ap_score
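
The last line guards an edge case: a slate with no relevant items makes the denominator zero, so the AP turns into NaN and is reset to 0. A hedged check of that behaviour, with a minimal stand-in for process_recsys_components (the real helper, not shown here, is assumed to reorder targets by descending output score):

import torch

def process_recsys_components(outputs, targets):
    # Minimal stand-in: sort each row of targets by descending output score.
    order = outputs.argsort(dim=1, descending=True)
    return torch.gather(targets, 1, order)

outputs = torch.tensor([[3.0, 2.0, 1.0]])
targets = torch.tensor([[0.0, 0.0, 0.0]])   # no relevant items in the slate
print(avg_precision(outputs, targets))      # tensor([0.])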
Example #15
def train(dataset,
          max_iter,
          test_sampler,
          model,
          optimizer,
          device,
          amp,
          save=1000):

    status = Status(max_iter)
    scaler = GradScaler() if amp else None

    while status.batches_done < max_iter:
        for src in dataset:
            optimizer.zero_grad()

            src = src.to(device)

            with autocast(amp):
                # VAE(src)
                dst, _, mu, logvar = model(src)
                # loss
                recons_loss = recons(dst, src)
                kld = KL_divergence(mu, logvar)
                loss = recons_loss + kld

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
            else:
                loss.backward()
                optimizer.step()

            # save
            if status.batches_done % save == 0:
                model.eval()
                with torch.no_grad():
                    images = model.decoder(test_sampler())
                model.train()
                save_image(
                    images,
                    f'implementations/VAE/result/{status.batches_done}.jpg',
                    nrow=4,
                    normalize=True,
                    range=(-1, 1))
                recons_images = _image_grid(src, dst)
                save_image(
                    recons_images,
                    f'implementations/VAE/result/recons_{status.batches_done}.jpg',
                    nrow=6,
                    normalize=True,
                    range=(-1, 1))
                torch.save(
                    model.state_dict(),
                    f'implementations/VAE/result/model_{status.batches_done}.pt'
                )

            # updates
            loss_dict = dict(
                loss=loss.item() if not torch.isnan(loss).any() else 0)
            status.update(loss_dict)
            if scaler is not None:
                scaler.update()

            if status.batches_done == max_iter:
                break
Example #16
    calibration_error = ece(np.asarray(probs), np.asarray(labels))
    print("Corruption {} ACC: {:.3f} ECE: {:.3f}".format(
        i + 1, val_c_score, calibration_error))
for i, batch in enumerate(ood_loader):

    if args.ood == 'svhn':
        x = batch[0].to(device)
    else:
        x = batch.to(device)

    with torch.no_grad():
        latents = ctnf.encoder(x)
        gamma, sldj = ctnf.flows(latents)
        dirichlet, logpsz = ctnf.surnorm(gamma)

    if torch.isnan(gamma).any():
        print("[DEBUG] NaN value is detected in svhn_gamma")
        sys.exit()

    preds = ctnf.predict(dirichlet, sldj + logpsz)
    if torch.isnan(preds).any():
        print("[DEBUG] NaN value is detected in svhn_preds")
        sys.exit()
    ood_probs.append(preds)
    preds = preds.to(device)

    if i == 3: break

ood_probs = torch.cat(ood_probs)
valid_probs = valid_probs.detach().cpu().numpy()
ood_probs = ood_probs.detach().cpu().numpy()
Example #17
def primal_dual_interior_point_with_eq(x,
                                       obj,
                                       eq_cons=None,
                                       should_stop=None,
                                       u=10.0,
                                       tolerance=1e-3,
                                       constraint_tolerance=1e-3,
                                       alpha=0.1,
                                       beta=0.5,
                                       fast=False,
                                       verbose=False,
                                       duals=None):
    from torch.autograd.functional import jacobian
    from torch.autograd.functional import hessian

    n = x.size(0)
    d = eq_cons.size()

    if duals is None:
        v = x.new_ones(d)
    else:
        v = duals

    def l(x, v):
        return obj(x) + eq_cons(x) @ v

    def residual(x, v):
        r_dual = jacobian(lambda x: l(x, v), x)
        r_pri = eq_cons(x)
        return r_dual, r_pri

    def r_norm(x, v):
        r_dual, r_pri = residual(x, v)
        norm = torch.cat([r_dual, r_pri]).norm(2)
        return norm

    def jacobian_(f, x):
        if f.type() == LINEAR:
            return f.A
        else:
            return jacobian(f, x)

    def hessian_(x, v):
        if eq_cons.type() == LINEAR and obj.type() == LINEAR:
            return 0.0
        else:
            return hessian(lambda x: l(x, v), x)

    should_stop = should_stop or []
    not_improving = NotImproving()

    while True:
        r_dual, r_pri = residual(x, v)
        obj_value = obj(x)
        norm = torch.cat([r_dual, r_pri]).norm(2)

        if verbose:
            logger.info("obj:%s,r_pri:%s,r_dual:%s,norm:%s", obj_value,
                        r_pri.norm(2), r_dual.norm(2), norm)
        if r_pri.norm(2) <= constraint_tolerance and r_dual.norm(
                2) <= constraint_tolerance and norm <= tolerance:
            return x, obj_value, OPTIMAL, v

        if not_improving(norm):
            return x, obj_value, SUB_OPTIMAL, v

        if torch.isnan(obj_value):
            return x, obj_value, FAIELD, v

        h2 = hessian_(x, v)
        A = jacobian_(eq_cons, x)

        if fast and not (isinstance(h2, float) and h2 == 0.0):
            _dir_x, _dir_v = solve_kkt_fast(h2, A, r_dual, r_pri)
        else:
            _dir_x, _dir_v = solve_kkt(h2, A, r_dual, r_pri, n, d)

        step = line_search(r_norm, (x, v), (_dir_x, _dir_v),
                           norm,
                           alpha=alpha,
                           beta=beta)

        x = x + step * _dir_x
        v = v + step * _dir_v

        for ss in should_stop:
            if ss(x, obj_value, None):
                return x, obj_value, USER_STOPPED, v
Example #18
    def train(self, report_loss_fn=None):
        # type: (Optional[Callable[[messages.Message], None]]) -> None

        (partition_start, partition_end) = self.partition

        def report_trainer_death(idx):
            # type: (int) -> None

            if report_loss_fn is not None:
                report_loss_fn(messages.TrainerDeathMessage(
                    (idx + self.batch_size, partition_end),
                ))

        for idx in range(partition_start, partition_end, self.batch_size):
            batch_loss_sum = np.zeros(self.num_losses)
            self.correct = 0

            self.optimizer.zero_grad()
            loss_tensor = torch.FloatTensor([0]).squeeze()
            batch = self.data.train[idx:idx+self.batch_size]

            if not batch:
                continue

            for datum in batch:
                output = self.model(datum)

                if torch.isnan(output).any():
                    report_trainer_death(idx)
                    return

                #target as a tensor
                target = self.get_target(datum)

                #get the loss value
                if self.loss_fn:
                    losses_opt = self.loss_fn(output, target)

                if self.predict_log and self.loss_fn:
                    losses_rep = self.loss_fn(output.exp(), target.exp())
                else:
                    losses_rep = losses_opt

                #check how many are correct
                if self.typ == PredictionType.CLASSIFICATION:
                    self.correct_classification(output, target)
                elif self.typ == PredictionType.REGRESSION:
                    self.correct_regression(output, target)

                #accumulate the losses
                for class_idx, (loss_opt, loss_rep) in enumerate(zip(losses_opt, losses_rep)):
                    loss_tensor += loss_opt
                    l = loss_rep.item()
                    batch_loss_sum[class_idx] += l

            batch_loss_avg = batch_loss_sum / len(batch)

            #propagate gradients
            loss_tensor.backward()

            #clip the gradients
            if self.clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

            for param in self.model.parameters():
                if param.grad is None:
                    continue

                if torch.isnan(param.grad).any():
                    report_trainer_death(idx)
                    return

            #optimizer step to update parameters
            self.optimizer.step()

            # get those tensors out of here!
            for datum in batch:
                self.model.remove_refs(datum)

            if report_loss_fn is not None:
                report_loss_fn(messages.LossReportMessage(
                    self.rank,
                    batch_loss_avg[0],
                    len(batch),
                ))
Example #19
            net.eval()
            torch.set_grad_enabled(False)


        # Store loss
        mean_volume_loss = 0 
        max_grad = 0
        mean_psnr = 0
        mean_time = 0
        mean_repro = 0
        mean_repro_ssim = 0
        # Training
        for ix,(curr_img_stack, local_volumes) in enumerate(curr_loader):

            # If empty or nan in volumes, don't use these for training 
            if curr_img_stack.float().sum()==0 or torch.isnan(curr_img_stack.float().max()):
                continue
            # Normalize volumes if ill posed
            if local_volumes.float().max()>=20000:
                local_volumes = local_volumes.float()
                local_volumes = local_volumes / local_volumes.max() * 4500.0
                local_volumes = local_volumes.half()

            # curr_img_stack returns both the dense and the sparse images, here we only need the sparse.
            if net.tempConv is None:
                assert len(curr_img_stack.shape)>=5, "If sparse is used curr_img_stack should contain both images, dense and sparse stacked in the last dim."
                curr_img_sparse = curr_img_stack[...,-1].clone().to(device) 
                curr_img_stack = curr_img_stack[...,-1].clone().to(device)
            else:
                curr_img_sparse = curr_img_stack[...,-1].clone().to(device)
                curr_img_stack = curr_img_stack[...,0].clone().to(device)
Example #20
def train(config, model, seg_loss_fn, optimizer, dataset_loaders):
    if is_main_process(config):
        time_str = time.strftime("%Y-%m-%d___%H-%M-%S", time.localtime())
        log_dir = os.path.join(config['log_dir'], config['net_name'],
                               config['dataset'], config['note'], time_str)
        checkpoint_path = os.path.join(log_dir,
                                       'model-last-%d.pkl' % config['epoch'])

        total_param = sum(p.numel() for p in model.parameters())
        train_param = sum(p.numel() for p in model.parameters()
                          if p.requires_grad)
        config.total_param = total_param / (1024 * 1024)
        config.train_param = train_param / (1024 * 1024)
        writer = init_writer(config, log_dir)

    motionseg_metric = MotionSegMetric(config.exception_value)

    if is_main_process(config):
        tqdm_epoch = trange(config['epoch'],
                            desc='{} epochs'.format(config.note),
                            leave=True)
    else:
        tqdm_epoch = range(config.epoch)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    step_acc = 0
    for epoch in tqdm_epoch:
        for split in ['train', 'val']:
            if split == 'train':
                model.train()
            else:
                model.eval()

            motionseg_metric.reset()

            if is_main_process(config):
                tqdm_step = tqdm(dataset_loaders[split],
                                 desc='steps',
                                 leave=False)
            else:
                tqdm_step = dataset_loaders[split]

            total_time = 0
            counter = 0
            N = len(dataset_loaders[split])
            for step, data in enumerate(tqdm_step):
                images, origin_labels, resize_labels = prepare_input_output(
                    data=data, config=config, device=device)

                if split == 'train':
                    poly_lr_scheduler(config,
                                      optimizer,
                                      iter=epoch * N + step,
                                      max_iter=config.epoch * N)

                if config.net_name.startswith('motion'):
                    start_time = time.time()
                    outputs = model.forward(images)
                    total_time += (time.time() - start_time)
                    counter += images[0].shape[0]
                else:
                    #assert config.input_format=='n'
                    start_time = time.time()
                    outputs = model.forward(torch.cat(images, dim=1))
                    total_time += (time.time() - start_time)
                    counter += images[0].shape[0]

                if config.net_name == 'motion_anet':
                    mask_gt = torch.squeeze(resize_labels[0], dim=1)
                    mask_loss_value = 0
                    for mask in outputs['masks']:
                        mask_loss_value += seg_loss_fn(mask, mask_gt)
                elif config.net_name == 'motion_diff' or not config.net_name.startswith(
                        'motion'):
                    gt_plus = (resize_labels[0] -
                               resize_labels[1]).clamp_(min=0).float()
                    gt_minus = (resize_labels[1] -
                                resize_labels[0]).clamp_(min=0).float()
                    mask_gt = torch.cat(
                        [gt_plus, gt_minus, resize_labels[0].float()], dim=1)
                    ignore_index = 255

                    if config.net_name == 'motion_diff':
                        predict = outputs['masks'][0]
                    else:
                        predict = outputs
                    predict[mask_gt == ignore_index] = 0
                    mask_gt[mask_gt == ignore_index] = 0
                    mask_loss_value = seg_loss_fn(predict.float(),
                                                  mask_gt.float())
                else:
                    mask_loss_value = seg_loss_fn(
                        outputs['masks'][0],
                        torch.squeeze(resize_labels[0], dim=1))

                if config['net_name'].find('_stn') >= 0:
                    if config['stn_object'] == 'features':
                        stn_loss_value = stn_loss(outputs['features'],
                                                  resize_labels[0].float(),
                                                  outputs['pose'],
                                                  config['pose_mask_reg'])
                    elif config['stn_object'] == 'images':
                        stn_loss_value = stn_loss(outputs['stn_images'],
                                                  resize_labels[0].float(),
                                                  outputs['pose'],
                                                  config['pose_mask_reg'])
                    else:
                        assert False, 'unknown stn object %s' % config[
                            'stn_object']

                    total_loss_value = mask_loss_value * config[
                        'motion_loss_weight'] + stn_loss_value * config[
                            'stn_loss_weight']
                else:
                    stn_loss_value = torch.tensor(0.0)
                    total_loss_value = mask_loss_value

                #assert not torch.isnan(total_loss_value),'find nan loss'
                if torch.isnan(total_loss_value) or torch.isinf(
                        total_loss_value):
                    raise RuntimeError("find nan or inf loss")

                if config.net_name == 'motion_diff' or not config.net_name.startswith(
                        'motion'):
                    if config.net_name == 'motion_diff':
                        predict = outputs['masks'][0]
                    else:
                        predict = outputs

                    #predict[:,2:3,:,:]=predict[:,2:3,:,:]+predict[:,0:1,:,:]-predict[:,1:2,:,:]
                    origin_mask = F.interpolate(
                        predict[:, 2:3, :, :],
                        size=origin_labels[0].shape[2:4],
                        mode='bilinear')
                    origin_mask = torch.cat([1 - origin_mask, origin_mask],
                                            dim=1)
                else:
                    origin_mask = F.interpolate(
                        outputs['masks'][0],
                        size=origin_labels[0].shape[2:4],
                        mode='bilinear')

                motionseg_metric.update({
                    "fmeasure": (origin_mask, origin_labels[0]),
                    "stn_loss":
                    stn_loss_value.item(),
                    "mask_loss":
                    mask_loss_value.item(),
                    "total_loss":
                    total_loss_value.item()
                })

                if split == 'train':
                    total_loss_value.backward()
                    if (step_acc + 1) >= config.accumulate:
                        optimizer.step()
                        optimizer.zero_grad()
                        step_acc = 0
                    else:
                        step_acc += 1

            if is_main_process(config):
                fps = counter / total_time
                writer.add_scalar(split + '/fps', fps, epoch)
                motionseg_metric.write(writer, split, epoch)
                current_metric = motionseg_metric.fetch()
                fmeasure = current_metric['fmeasure'].item()
                mean_total_loss = current_metric['total_loss']
                if split == 'train':
                    tqdm_epoch.set_postfix(train_fmeasure=fmeasure)
                else:
                    tqdm_epoch.set_postfix(val_fmeasure=fmeasure)

                if epoch % 10 == 0:
                    print(split, 'fmeasure=%0.4f' % fmeasure, 'total_loss=',
                          mean_total_loss)

    if is_main_process(config) and config['save_model']:
        torch.save(model.state_dict(), checkpoint_path)

    if is_main_process(config):
        writer.close()
Example #21
    def torch_likelihood(self,
                         tensor: torch.Tensor,
                         log_output: bool = True,
                         normalized: bool = False,
                         imputing: bool = True) -> torch.Tensor:
        """ Compute the likelihood or loglikehood of the given data points.

        Parameters
        ----------
        tensor: torch.Tensor of dimension (n_samples, n_features)
            Data points for which we want the likelihood.
            It is your responsibility to move the tensor to the desired device,
            and to split it if computations wouldn't fit on the device.
        log_output : bool
            Whether to return log-probabilities or probabilities.
        normalized: bool
            Whether outputs should be probabilities or plain scores.
            Unnormalized scores are faster and just as good for anomaly detection.
        imputing : bool
            Whether to replace NaNs with mean values.
            Set imputing to False if you are sure there are no NaN values.

        Returns
        -------
        torch.Tensor of dimension (n_samples, )
            Likelihood of the input. Stored on same device as the input.
        """

        if not self._ready_for_prediction:
            for flow in self._flows:
                flow.move_to_device(
                    self._device)  # this includes reparametrization
            self._ready_for_prediction = True
            self._torch_means = torch.from_numpy(self._means).to(self._device)
            self._torch_stds = torch.from_numpy(self._stds).to(self._device)

        tensor -= self._torch_means
        tensor /= self._torch_stds

        if imputing:
            tensor[torch.isnan(tensor)] = 0

        # inverse the flow
        ys = []
        y = tensor
        for flow in reversed(self._flows):
            y = flow.inverse(y)
            ys.append(y)

        # log-likelihood loss
        ll = -y.pow(2).sum(dim=1) / 2  # independent normal distributions
        for f, z in zip(self._flows, reversed(ys)):
            # the minus comes from moving the ^-1 when applying the logarithm
            ll -= f.log_abs_jac_det(z)

        if not normalized:
            ll += GaussNormalizer

        if not log_output:
            ll = ll.exp()

        return ll
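
A tiny, self-contained sketch of the scale-then-impute step the docstring describes (the means/stds here are hypothetical; the flow itself is not shown):

import torch

x = torch.tensor([[1.0, float('nan')], [3.0, 4.0]])
means, stds = torch.tensor([[2.0, 4.0]]), torch.tensor([[1.0, 2.0]])

x = (x - means) / stds
x[torch.isnan(x)] = 0    # impute missing values with the (scaled) mean, i.e. zero
print(x)                 # tensor([[-1., 0.], [1., 0.]])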
Example #22
    def prepare(self, priming_data, previous_target_data=None, feedback_hoop_function=None, batch_size=256):
        """
        The usual, run this on the initial training data for the encoder
        :param priming_data: a list of (self._n_dims)-dimensional time series [[dim1_data], ...]
        :param previous_target_data: tensor with encoded previous target values for autoregressive tasks
        :param feedback_hoop_function: [if you want to get feedback on the training process]
        :param batch_size
        :return:
        """
        if self._prepared:
            raise Exception('You can only call "prepare" once for a given encoder.')
        else:
            self.setup_nn(previous_target_data)

        # Convert to array and determine max length
        priming_data, lengths_data = self._prepare_raw_data(priming_data)
        self._max_ts_length = int(lengths_data.max())

        if self._normalizer:
            self._normalizer.prepare(priming_data)
            priming_data = torch.stack([self._normalizer.encode(d) for d in priming_data]).to(self.device)
        else:
            priming_data = torch.stack([d for d in priming_data]).unsqueeze(-1).to(self.device)

        # merge all normalized data into a training batch
        if previous_target_data is not None and len(previous_target_data) > 0:
            normalized_tensors = []
            for target_dict in previous_target_data:
                normalizer = target_dict['normalizer']
                self._target_ar_normalizers.append(normalizer)
                data = torch.Tensor(normalizer.encode(target_dict['data'])).to(self.device)
                data[torch.isnan(data)] = 0.0
                if len(data.shape) < 3:
                    data = data.unsqueeze(-1)  # add feature dimension
                normalized_tensors.append(data)

            normalized_data = torch.cat(normalized_tensors, dim=-1)
            priming_data = torch.cat([priming_data, normalized_data], dim=-1)

        self._encoder.train()
        for i in range(self._epochs):
            average_loss = 0

            for batch_idx in range(0, len(priming_data), batch_size):
                # setup loss and optimizer
                self._optimizer.zero_grad()
                loss = 0

                # shape: (batch_size, timesteps, n_dims)
                batch = self._get_batch(priming_data, batch_idx, min(batch_idx + batch_size, len(priming_data)))

                # encode and decode through time
                with LightwoodAutocast():
                    if self.encoder_class == TransformerEncoder:
                        # pack batch length info tensor
                        len_batch = self._get_batch(lengths_data, batch_idx, min(batch_idx + batch_size, len(priming_data)))
                        batch = batch, len_batch

                        next_tensor, hidden_state, dec_loss = self._encoder.bptt(batch, self._enc_criterion, self.device)
                        loss += dec_loss

                    else:
                        next_tensor, hidden_state, enc_loss = self._encoder.bptt(batch, self._enc_criterion, self.device)
                        loss += enc_loss

                        next_tensor, hidden_state, dec_loss = self._decoder.decode(batch, next_tensor, self._dec_criterion,
                                                                                   self.device,
                                                                                   hidden_state=hidden_state)
                        loss += dec_loss

                loss.backward()

                self._optimizer.step()
                average_loss += loss.item()

            average_loss = average_loss / len(priming_data)
            batch_idx += batch_size

            if average_loss < self._stop_on_error:
                break
            if feedback_hoop_function is not None:
                feedback_hoop_function("epoch [{epoch_n}/{total}] average_loss = {average_loss}".format(
                    epoch_n=i + 1,
                    total=self._epochs,
                    average_loss=average_loss))

        self._prepared = True
Example #23
 def backward(self, grad_output):
     if torch.isnan(grad_output).any():
         return grad_output.zero_()
     else:
         return (grad_output * -self.lambd)
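
This backward looks like part of a gradient-reversal style layer with a NaN guard (self.lambd is assumed to be the reversal scale). A hedged, self-contained sketch of the same idea as a torch.autograd.Function:

import torch

class ReverseGradNaNGuard(torch.autograd.Function):
    # Hedged sketch, not the original class: identity forward; on backward the
    # gradient is reversed and scaled, or zeroed if it contains any NaN.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        if torch.isnan(grad_output).any():
            return torch.zeros_like(grad_output), None
        return grad_output * -ctx.lambd, None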
Example #24
File: prme.py Project: zan12/prme
 def fit_v_ell_inference_decoder(self, x_bow, N_total, z_a, z_b, global_iter):
   optimizer_v_ell_inference_decoder = optim.Adam([
       {'params': self.v},
       {'params': self.ell},
       {'params': self.inference_network.parameters()},
       {'params': self.decoder_network.parameters()}
       ], lr=self.lr)
   N = N_total
   network_iter = self.inner_iter
   prev_loss = 0   
   for iter_v_ell_inference_decoder in range(network_iter):
     optimizer_v_ell_inference_decoder.zero_grad()
     h = self.inference_network(x_bow)
     hl = torch.cat((h.repeat(1,self.K).view(N*self.K, self.D_h),
         self.ell.repeat(N,1)), 1)
     decoder_output1, decoder_output2 = self.decoder_network(hl)
     decoder_mu_theta = decoder_output1.view(N, self.K)
     decoder_sigma2_theta = decoder_output2.view(N, self.K)
     decoder_sigma2_theta.data.clamp_(min=1e-6, max=100)
     ln_p_v = (self.alpha0-1)*torch.sum(torch.log(1-self.v)) + CONST
     ln_p_k = torch.cat((
         torch.log(self.v),torch.ones(1).to(self.device)), 0) + torch.cat((
         torch.zeros(1).to(self.device),torch.cumsum(
         torch.log(1-self.v), dim=0)), 0)
     E_ln_z = torch.digamma(z_a) + torch.log(z_b)
     E_ln_p_z = -N*torch.sum(torch.lgamma(
       self.beta*torch.exp(ln_p_k))) - self.beta*torch.dot(torch.exp(ln_p_k),
       torch.sum(decoder_mu_theta-E_ln_z, dim=0)) - torch.sum(
       E_ln_z) - torch.sum(z_a*z_b*torch.exp(
       -decoder_mu_theta+decoder_sigma2_theta/2))
     E_ln_p_h = -N*self.D_h/2*torch.log(
         torch.tensor(2*np.pi*self.a0).to(self.device))-1/2/self.a0*torch.sum(
         h.pow(2))
     ln_p_ell = -self.K*self.D_ell/2*torch.log(
         2*np.pi*torch.tensor(self.b0).to(self.device)) - torch.norm(
         self.ell).pow(2)/2/self.b0
     network_norm = 0
     for param in self.inference_network.parameters():
       network_norm += torch.norm(param)
     for param in self.decoder_network.parameters():
       network_norm += torch.norm(param)
     net_norm = self.network_reg*network_norm
     loss = -ln_p_v-N_total/N*E_ln_p_z-N_total/N*E_ln_p_h-ln_p_ell+net_norm
     loss.backward()
     optimizer_v_ell_inference_decoder.step()
     self.v.data.clamp_(min=1e-6, max=1-1e-6)
     print(iter_v_ell_inference_decoder, loss)
     if torch.isnan(loss.data):
       print(h,
             ln_p_v,
             ln_p_k,
             E_ln_p_z,
             E_ln_p_h,
             ln_p_ell,
             network_norm)
       raise ValueError('Nan loss!')
     if (torch.abs((prev_loss-loss)/loss) <= 1e-6 and 
         iter_v_ell_inference_decoder>=50) or (iter_v_ell_inference_decoder
         == network_iter-1):
       break
     prev_loss = loss
Example #25
    def train(self):
        # Single epoch training routine

        losses = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()
        self.motion_discriminator.train()

        start = time.time()

        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}',
                  fill='#',
                  max=self.num_iters_per_epoch)

        for i in range(self.num_iters_per_epoch):
            # Dirty solution to reset an iterator
            target_2d = target_3d = None
            if self.train_2d_iter:
                try:
                    target_2d = next(self.train_2d_iter)
                except StopIteration:
                    self.train_2d_iter = iter(self.train_2d_loader)
                    target_2d = next(self.train_2d_iter)

                move_dict_to_device(target_2d, self.device)

            if self.train_3d_iter:
                try:
                    target_3d = next(self.train_3d_iter)
                except StopIteration:
                    self.train_3d_iter = iter(self.train_3d_loader)
                    target_3d = next(self.train_3d_iter)

                move_dict_to_device(target_3d, self.device)

            real_body_samples = real_motion_samples = None

            try:
                real_motion_samples = next(self.disc_motion_iter)
            except StopIteration:
                self.disc_motion_iter = iter(self.disc_motion_loader)
                real_motion_samples = next(self.disc_motion_iter)

            move_dict_to_device(real_motion_samples, self.device)

            # <======= Feedforward generator and discriminator
            if target_2d and target_3d:
                inp = torch.cat((target_2d['features'], target_3d['features']),
                                dim=0).to(self.device)
            elif target_3d:
                inp = target_3d['features'].to(self.device)
            else:
                inp = target_2d['features'].to(self.device)

            timer['data'] = time.time() - start
            start = time.time()

            preds = self.generator(inp)

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss, motion_dis_loss, loss_dict = self.criterion(
                generator_outputs=preds,
                data_2d=target_2d,
                data_3d=target_3d,
                data_body_mosh=real_body_samples,
                data_motion_mosh=real_motion_samples,
                motion_discriminator=self.motion_discriminator,
            )
            # =======>

            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            if self.train_global_step % self.dis_motion_update_steps == 0:
                self.dis_motion_optimizer.zero_grad()
                motion_dis_loss.backward()
                self.dis_motion_optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = gen_loss + motion_dis_loss

            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
                             f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'

            for k, v in loss_dict.items():
                summary_string += f' | {k}: {v:.2f}'
                self.writer.add_scalar('train_loss/' + k,
                                       v,
                                       global_step=self.train_global_step)

            for k, v in timer.items():
                summary_string += f' | {k}: {v:.2f}'

            self.writer.add_scalar('train_loss/loss',
                                   total_loss.item(),
                                   global_step=self.train_global_step)

            if self.debug:
                print('==== Visualize ====')
                from lib.utils.vis import batch_visualize_vid_preds
                video = target_3d['video']
                dataset = 'spin'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target_3d.copy(),
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('train-video',
                                      vid_tensor,
                                      global_step=self.train_global_step,
                                      fps=10)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next()

            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>

        bar.finish()

        logger.info(summary_string)
Example #26
def train(task_class, model, train_data, num_epochs, lr, device, dev_data=None,
          cert_frac=0.0, initial_cert_frac=0.0, cert_eps=1.0, initial_cert_eps=0.0, non_cert_train_epochs=0, full_train_epochs=0,
          batch_size=1, epochs_per_save=1, augmenter=None, clip_grad_norm=0, weight_decay=0,
          save_best_only=False):
    print('Training model')
    sys.stdout.flush()
    loss_func = task_class.LOSS_FUNC
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay)
    zero_stats = {'epoch': 0, 'clean_acc': 0.0, 'cert_acc': 0.0}
    if augmenter:
        zero_stats['aug_acc'] = 0.0
    all_epoch_stats = {
        "loss": {"total": [],
                 "clean": [],
                 "cert": []},
        "cert": {"frac": [],
                 "eps": []},
        "acc": {
            "train": {
                "clean": [],
                "cert": []},
            "dev": {
                "clean": [],
                "cert": []},
            "best_dev": {
                "clean": [zero_stats],
                "cert": [zero_stats]}},
        "total_epochs": num_epochs,
    }
    aug_dev_data = None
    if augmenter:
        all_epoch_stats['acc']['dev']['aug'] = []
        all_epoch_stats['acc']['best_dev']['aug'] = [zero_stats]
        print('Augmenting training data')
        aug_train_data = augmenter.augment(train_data)
        data = aug_train_data.get_loader(batch_size)
        if dev_data:
            print('Augmenting dev data')
            # Augment dev set now, for early stopping
            aug_dev_data = augmenter.augment(dev_data)
    else:
        # Create all batches now and pin them in memory
        data = train_data.get_loader(batch_size)
    # Linearly increase the weight of adversarial loss over all the epochs to end up at the final desired fraction
    cert_schedule = torch.tensor(np.linspace(initial_cert_frac, cert_frac, num_epochs -
                                             full_train_epochs - non_cert_train_epochs), dtype=torch.float, device=device)
    eps_schedule = torch.tensor(np.linspace(initial_cert_eps, cert_eps, num_epochs -
                                            full_train_epochs - non_cert_train_epochs), dtype=torch.float, device=device)
    for t in range(num_epochs):
        model.train()
        if t < non_cert_train_epochs:
            cur_cert_frac = 0.0
            cur_cert_eps = 0.0
        else:
            cur_cert_frac = cert_schedule[t - non_cert_train_epochs] if t - \
                non_cert_train_epochs < len(cert_schedule) else cert_schedule[-1]
            cur_cert_eps = eps_schedule[t - non_cert_train_epochs] if t - \
                non_cert_train_epochs < len(eps_schedule) else eps_schedule[-1]
        epoch = {
            "total_loss": 0.0,
            "clean_loss": 0.0,
            "cert_loss": 0.0,
            "num_correct": 0,
            "num_cert_correct": 0,
            "num": 0,
            "clean_acc": 0,
            "cert_acc": 0,
            "dev": {},
            "best_dev": {},
            "cert_frac": cur_cert_frac if isinstance(cur_cert_frac, float) else cur_cert_frac.item(),
            "cert_eps": cur_cert_eps if isinstance(cur_cert_eps, float) else cur_cert_eps.item(),
            "epoch": t,
        }
        with tqdm(data) as batch_loop:
            for batch_idx, batch in enumerate(batch_loop):
                batch = data_util.dict_batch_to_device(batch, device)
                optimizer.zero_grad()
                if cur_cert_frac > 0.0:
                    out = model.forward(batch, cert_eps=cur_cert_eps)
                    logits = out.val
                    loss = loss_func(logits, batch['y'])
                    epoch["clean_loss"] += loss.item()
                    cert_loss = torch.max(loss_func(out.lb, batch['y']),
                                          loss_func(out.ub, batch['y']))
                    loss = cur_cert_frac * cert_loss + \
                        (1.0 - cur_cert_frac) * loss
                    epoch["cert_loss"] += cert_loss.item()
                else:
                    # Bypass computing bounds during training
                    logits = out = model.forward(batch, compute_bounds=False)
                    loss = loss_func(logits, batch['y'])
                epoch["total_loss"] += loss.item()
                epoch["num"] += len(batch['y'])
                num_correct, num_cert_correct = task_class.num_correct(
                    out, batch['y'])
                epoch["num_correct"] += num_correct
                epoch["num_cert_correct"] += num_cert_correct
                loss.backward()
                if any(p.grad is not None and torch.isnan(p.grad).any() for p in model.parameters()):
                    nan_params = [name for name, p in model.named_parameters()
                                  if p.grad is not None and torch.isnan(p.grad).any()]
                    print('NaN found in gradients: %s' %
                          nan_params, file=sys.stderr)
                else:
                    if clip_grad_norm:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), clip_grad_norm)
                    optimizer.step()
            if cert_frac > 0.0:
                print("Epoch {epoch:>3}: train loss: {total_loss:.6f}, clean_loss: {clean_loss:.6f}, cert_loss: {cert_loss:.6f}".format(
                    **epoch))
            else:
                print(
                    "Epoch {epoch:>3}: train loss: {total_loss:.6f}".format(**epoch))
            sys.stdout.flush()

        epoch["clean_acc"] = 100.0 * epoch["num_correct"] / epoch["num"]
        acc_str = "  Train accuracy: {num_correct}/{num} = {clean_acc:.2f}".format(
            **epoch)
        if cert_frac > 0.0:
            epoch["cert_acc"] = 100.0 * \
                epoch["num_cert_correct"] / epoch["num"]
            acc_str += ", certified {num_cert_correct}/{num} = {cert_acc:.2f}".format(
                **epoch)
        print(acc_str)
        is_best = False
        if dev_data:
            dev_results = test(task_class, model, "Dev", dev_data, device, batch_size=batch_size,
                               aug_dataset=aug_dev_data)
            epoch['dev'] = dev_results
            all_epoch_stats['acc']['dev']['clean'].append(
                dev_results['clean_acc'])
            all_epoch_stats['acc']['dev']['cert'].append(
                dev_results['cert_acc'])
            if augmenter:
                all_epoch_stats['acc']['dev']['aug'].append(
                    dev_results['aug_acc'])
            dev_stats = {
                'epoch': t,
                'loss': dev_results['loss'],
                'clean_acc': dev_results['clean_acc'],
                'cert_acc': dev_results['cert_acc']
            }
            if augmenter:
                dev_stats['aug_acc'] = dev_results['aug_acc']
            if dev_results['clean_acc'] > all_epoch_stats['acc']['best_dev']['clean'][-1]['clean_acc']:
                all_epoch_stats['acc']['best_dev']['clean'].append(dev_stats)
                if cert_frac == 0.0 and not augmenter:
                    is_best = True
            if dev_results['cert_acc'] > all_epoch_stats['acc']['best_dev']['cert'][-1]['cert_acc']:
                all_epoch_stats['acc']['best_dev']['cert'].append(dev_stats)
                if cert_frac > 0.0:
                    is_best = True
            if augmenter and dev_results['aug_acc'] > all_epoch_stats['acc']['best_dev']['aug'][-1]['aug_acc']:
                all_epoch_stats['acc']['best_dev']['aug'].append(dev_stats)
                if cert_frac == 0.0 and augmenter:
                    is_best = True
            epoch['best_dev'] = {
                'clean': all_epoch_stats['acc']['best_dev']['clean'][-1],
                'cert': all_epoch_stats['acc']['best_dev']['cert'][-1]}
            if augmenter:
                epoch['best_dev']['aug'] = all_epoch_stats['acc']['best_dev']['aug'][-1]
        all_epoch_stats["loss"]['total'].append(epoch["total_loss"])
        all_epoch_stats["loss"]['clean'].append(epoch["clean_loss"])
        all_epoch_stats["loss"]['cert'].append(epoch["cert_loss"])
        all_epoch_stats['cert']['frac'].append(epoch["cert_frac"])
        all_epoch_stats['cert']['eps'].append(epoch["cert_eps"])
        all_epoch_stats["acc"]['train']['clean'].append(epoch["clean_acc"])
        all_epoch_stats["acc"]['train']['cert'].append(epoch["cert_acc"])
        with open(os.path.join(OPTS.out_dir, "run_stats.json"), "w") as outfile:
            json.dump(epoch, outfile)
        with open(os.path.join(OPTS.out_dir, "all_epoch_stats.json"), "w") as outfile:
            json.dump(all_epoch_stats, outfile)
        if ((save_best_only and is_best)
            or (not save_best_only and epochs_per_save and (t+1) % epochs_per_save == 0)
                or t == num_epochs - 1):
            if save_best_only and is_best:
                for fn in glob.glob(os.path.join(OPTS.out_dir, 'model-checkpoint*.pth')):
                    os.remove(fn)
            model_save_path = os.path.join(
                OPTS.out_dir, "model-checkpoint-{}.pth".format(t))
            print('Saving model to %s' % model_save_path)
            torch.save(model.state_dict(), model_save_path)

    return model
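The NaN-gradient guard used in the inner loop above (skip the optimizer step when any parameter gradient contains NaN, otherwise clip and step) can be factored into a small helper. A sketch, assuming that silently skipping the update is the desired behavior; safe_step is a hypothetical helper name.

import torch

def safe_step(model, optimizer, clip_grad_norm=0.0):
    """Apply optimizer.step() only if no parameter gradient contains NaN.

    Returns True if the step was taken, False if it was skipped.
    """
    nan_params = [name for name, p in model.named_parameters()
                  if p.grad is not None and torch.isnan(p.grad).any()]
    if nan_params:
        print('NaN found in gradients: %s' % nan_params)
        return False
    if clip_grad_norm:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    optimizer.step()
    return True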
Example #27
def compute_losses(model, data, losses_tracking, add_transformation_loss=True):
    losses = []

    # joint embedding loss
    imgs = np.stack([d['image'] for d in data])
    imgs = torch.from_numpy(imgs).float()
    if len(imgs.shape) == 2:
        imgs = model.img_encoder.fc(imgs.cuda())
    else:
        imgs = model.img_encoder(imgs.cuda())
    texts = [random.choice(d['captions']) for d in data]
    texts = model.text_encoder(texts)
    loss_name = 'joint_embedding'
    loss_weight = 1.0
    loss_value = model.pair_loss(texts, imgs).cuda()
    losses += [(loss_name, loss_weight, loss_value)]

    # transformation loss
    if add_transformation_loss:
        indices, source_texts, target_texts, replace_word = sample_word_pairs(
            [d['captions'] for d in data])
        target_imgs = [imgs[i, :] for i in indices]
        target_imgs = torch.stack(target_imgs)
        source_words = [i[0] for i in replace_word]
        target_words = [i[1] for i in replace_word]

        source_texts = model.text_encoder(source_texts).detach()
        source_words = model.text_encoder(source_words).detach()
        target_texts = model.text_encoder(target_texts).detach()
        target_words = model.text_encoder(target_words).detach()
        target_imgs = target_imgs.detach()

        source_texts_to_target = model.transformer(
            (source_texts, source_words, target_words))
        target_texts_to_source = model.transformer(
            (target_texts, target_words, source_words))
        target_imgs_to_source = model.transformer(
            (target_imgs, target_words, source_words))
        target_imgs_to_source_to_target = model.transformer(
            (target_imgs_to_source, source_words, target_words))
        pairs = [
            # sources (no text dups):
            # (1) source_texts
            # (2) target_texts_to_source
            # (3) target_imgs_to_source
            #(source_texts, target_texts_to_source),
            (target_imgs_to_source, source_texts),
            #(target_imgs_to_source, target_texts_to_source),

            # targets (dups!):
            # (1) target_texts
            # (2) source_texts_to_target
            # (3) target_imgs
            # (4) target_imgs_to_source_to_target
            #(target_texts, target_imgs),
            #(target_texts, source_texts_to_target),
            (target_imgs, source_texts_to_target),
            #(target_imgs_to_source_to_target, target_imgs),

            # combination
            (torch.cat((source_texts_to_target, target_texts_to_source)),
             torch.cat((target_texts, source_texts))),
            (torch.cat((source_texts_to_target, target_imgs_to_source)),
             torch.cat((target_imgs, source_texts))),
            (torch.cat((target_imgs, target_imgs_to_source)),
             torch.cat((target_texts, source_texts))),
        ]
        for i, p in enumerate(pairs):
            loss_value = model.pair_loss(p[0], p[1])
            loss_name = 'loss_transformation' + str(i + 1)
            loss_weight = 1.0 / len(pairs)
            losses += [(loss_name, loss_weight, loss_value)]

    # total loss
    total_loss = sum([
        loss_weight * loss_value
        for loss_name, loss_weight, loss_value in losses
    ])
    assert (not torch.isnan(total_loss))
    losses += [('total training loss', None, total_loss)]

    # save losses
    for loss_name, loss_weight, loss_value in losses:
        if loss_name not in losses_tracking:
            losses_tracking[loss_name] = []
        losses_tracking[loss_name].append(float(loss_value.data.item()))
    return total_loss
Example #28
def train(train_loader, test_loader, eval_points=None, epochs=5, model_type='GRU', layers=4, hidden=100, output='grasp',
          drop_prob=0.2, input_dim=None, train_points=0):
    # Set hyperparameters
    if input_dim is None:
        input_dim = 51
    if output == 'grasp':
        output_dim = 1
        loss_fn = nn.MSELoss()
        network_type = 1
    elif output == 'slip':
        output_dim = 1
        loss_fn = nn.MSELoss()
        network_type = 1
    elif output == 'contact':
        output_dim = 6
        loss_fn = nn.MSELoss()
        network_type = 2
    elif output == 'drop':
        output_dim = 1
        loss_fn = nn.MSELoss()
        network_type = 1
    batch_size = 5000
    # Instantiate the models
    if model_type == 'GRU':
        model = GRUNet(input_dim, hidden, output_dim, layers, drop_prob)
        model_copy = copy.deepcopy(model)
        backup_acc = 0

    elif model_type == 'LSTM':
        model = LSTMNet(input_dim, hidden, output_dim, layers, drop_prob)
        model_copy = copy.deepcopy(model)
        backup_acc = 0
    # Define loss function and optimizer
    if torch.cuda.is_available():
        model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)
    model.train()
    accs = []
    TP_rate = []
    FP_rate = []
    losses = []
    steps = []
    print('starting training, finding the starting accuracy for random model of type', output)
    acc, TP, FP = evaluate(model, test_loader, eval_points, network_type, input_dim)
    print(f'starting: accuracy - {acc}, TP rate - {TP}, FP rate - {FP}')
    accs.append(acc)
    TP_rate.append(TP)
    FP_rate.append(FP)
    losses.append(0)
    steps.append(0)
    net_loss = 0
    for epoch in range(1, epochs + 1):
        hiddens = model.init_hidden(batch_size)
        net_loss = 0
        epoch_loss = 0
        step = 0
        for x, label in train_loader:
            x = torch.reshape(x, (5000, 1, input_dim))
            if model_type == "GRU":
                hiddens = hiddens.data
            else:
                hiddens = tuple([e.data for e in hiddens])
            pred, hiddens = model(x.to(device).float(), hiddens)
            for param in model.parameters():
                if torch.isnan(param).any():
                    print('NaN detected in model parameters')
            if network_type == 1:
                pred = torch.reshape(pred, (5000,))
                loss = loss_fn(pred, label.to(device).float())
            else:
                loss = loss_fn(pred.to('cpu'), label.to('cpu').float())
            optim.zero_grad()
            loss.backward()
            optim.step()
            net_loss += float(loss)
            epoch_loss += float(loss)
            step += 1
        acc, TP, FP = evaluate(model, test_loader, eval_points, network_type, input_dim)
        print(f'epoch {epoch}: accuracy - {acc}, loss - {epoch_loss}, TP rate - {TP}, FP rate - {FP}')
        if acc > backup_acc:
            model_copy = copy.deepcopy(model)
            backup_acc = acc
        accs.append(acc)
        losses.append(net_loss)
        steps.append(epoch)
        TP_rate.append(TP)
        FP_rate.append(FP)
        net_loss = 0
    print(f'returning best recorded model with acc = {backup_acc}')
    return model_copy, accs, losses, steps, TP_rate, FP_rate
Example #29
    def attack_cw(self,
                  label,
                  out_dir=None,
                  save_title=None,
                  steps=5,
                  vertex_lr=0.001,
                  pose_lr=0.05,
                  lighting_lr=8000,
                  vertex_attack=True,
                  pose_attack=True,
                  lighting_attack=False,
                  target=None):

        if out_dir is not None and save_title is None:
            raise Exception("Must provide image title if out dir is provided")
        elif save_title is not None and out_dir is None:
            raise Exception("Must provide directory if image is to be saved")

        filename = save_title

        # classify
        img = self.render_image(out_dir=out_dir, filename=filename)

        if target is not None:
            target = torch.tensor([target]).to(pyredner.get_device())
            self.targeted = True
        else:
            target = torch.tensor([label]).to(pyredner.get_device())

        target_onehot = torch.zeros(target.size() + (self.NUM_CLASSES, )).to(
            pyredner.get_device())
        target_onehot.scatter_(1, target.unsqueeze(1), 1.)

        # only there to zero out gradients.
        optimizer = torch.optim.Adam([
            self.translation, self.euler_angles_modifier, self.light_modifier
        ] + [m for m in self.modifiers],
                                     lr=0)

        for i in range(steps):
            optimizer.zero_grad()
            pred, net_out = self.classify(img)
            if pred.item() != label and i != 0:
                final_image = np.clip(
                    img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
                return pred, final_image

            loss = 0
            if vertex_attack:
                dist = l2_dist(self.input_adv_list, self.input_orig_list,
                               False)
                loss += self.cw_loss(net_out, target_onehot, dist, 0.1)

            if pose_attack:
                dist = l2_dist(self.angle_input_adv_list,
                               self.angle_input_orig_list, False)
                loss += self.cw_loss(net_out, target_onehot, dist, 0.1)

            if lighting_attack:
                dist = l2_dist(self.light_input_adv_list,
                               self.light_input_orig_list, False)
                loss += self.cw_loss(net_out, target_onehot, dist, 0.1)

            # get gradients
            loss.backward(retain_graph=True)

            delta = 1e-6
            inf_count = 0
            nan_count = 0

            if vertex_attack:
                # attack each shape's vertices
                self.input_orig_list = []
                self.input_adv_list = []

                for shape, m in zip(self.shapes, self.modifiers):
                    shape.vertices = tanh_rescale(
                        torch_arctanh(shape.vertices.clone().detach()) -
                        m.clone().detach())
                    if not torch.isfinite(m.grad).all():
                        inf_count += 1
                    elif torch.isnan(m.grad).any():
                        nan_count += 1
                    else:
                        # subtract because we are trying to decrease the classification score of the label
                        m.data -= m.grad / (torch.norm(m.grad) +
                                            delta) * vertex_lr

                for shape, m in zip(self.shapes, self.modifiers):
                    self.input_orig_list.append(
                        tanh_rescale(torch_arctanh(shape.vertices)))
                    shape.vertices = tanh_rescale(
                        torch_arctanh(shape.vertices) + m)

                    self.input_adv_list.append(shape.vertices)

            if lighting_attack:
                self.light_input_orig_list = []
                self.light_input_adv_list = []
                # tanh_rescale(torch_arctanh(self.light_init_vals/torch.norm(self.light_init_vals)) + self.light_modifier/torch.norm(self.light_modifier + delta))
                tanh_factor = tanh_rescale(
                    torch_arctanh(
                        self.light_intensity.clone().detach() /
                        torch.norm(self.light_intensity.clone().detach())) -
                    self.light_modifier.clone().detach() /
                    torch.norm(self.light_modifier.clone().detach() + delta))
                self.light_init_vals = torch.norm(
                    self.light_intensity.clone().detach()) * torch.clamp(
                        tanh_factor, 0, 1)

                self.light_modifier.data -= self.light_modifier.grad / (
                    torch.norm(self.light_modifier.grad) + delta) * lighting_lr

                # redner can't accept negative light intensities, so we have to be a bit creative and work with lighting norms instead and then rescale them afterwards...
                tanh_factor = tanh_rescale(
                    torch_arctanh(self.light_init_vals /
                                  torch.norm(self.light_init_vals)) +
                    self.light_modifier /
                    torch.norm(self.light_modifier + delta))
                self.light_intensity = torch.norm(
                    self.light_init_vals) * torch.clamp(tanh_factor, 0, 1)

                self.light_input_orig_list.append(
                    self.light_init_vals / torch.norm(self.light_init_vals))

                self.light_input_adv_list.append(self.light_intensity)
                self.light = pyredner.PointLight(
                    position=(self.camera.position + torch.tensor(
                        (0.0, 0.0, 100.0))).to(pyredner.get_device()),
                    intensity=self.light_intensity)

            if pose_attack:
                self.angle_input_adv_list = []
                self.angle_input_orig_list = []

                self.euler_angles_modifier.data -= self.euler_angles_modifier.grad / (
                    torch.norm(self.euler_angles_modifier.grad) +
                    delta) * pose_lr
                self.euler_angles = tanh_rescale(
                    torch_arctanh(
                        torch.tensor([0., 0., 0.],
                                     device=pyredner.get_device())) +
                    self.euler_angles_modifier)
                self.angle_input_orig_list.append(
                    tanh_rescale(
                        torch_arctanh(
                            torch.tensor([0., 0., 0.],
                                         device=pyredner.get_device()))))
                self.angle_input_adv_list.append(self.euler_angles)

            img = self.render_image(out_dir=out_dir, filename=filename)

        final_pred, net_out = self.classify(img)
        final_image = np.clip(img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
        return final_pred, final_image
Example #30
     with torch.no_grad():
         if c_step+num_steps >= done_idx:
             next_value = torch.Tensor([[0]]).to(device)# vf.forward([next_obs]) 
         else:
             next_value = vf.forward([obs[end_idx]])
     batch_returns = np.append(np.zeros_like(batch_rewards), next_value[0][0].detach().cpu())
     for t in reversed(range(batch_rewards.shape[0])):
         batch_returns[t] = batch_rewards[t] + args.gamma * batch_returns[t+1] * (1-dones[t])
     batch_returns = batch_returns[:-1]
     # advantages are batch_returns - baseline, value estimates in our case
     advantages = batch_returns - values[c_step:end_idx].detach().cpu().numpy()
     
     vf_loss = loss_fn(torch.Tensor(batch_returns).to(device), values[c_step:end_idx]) * args.vf_coef
     pg_loss = torch.Tensor(advantages).to(device) * neglogprobs[c_step:end_idx]
     loss = (pg_loss - entropys[c_step:end_idx] * args.ent_coef).mean() + vf_loss
     if torch.isnan(loss):
         raise Exception()
     
     optimizer.zero_grad()
     loss.backward(retain_graph=True)
     nn.utils.clip_grad_norm_(list(pg.parameters()) + list(vf.parameters()), args.max_grad_norm)
     optimizer.step()
     c_step += num_steps
     
 # returns = np.zeros_like(rewards)
 # for t in reversed(range(rewards.shape[0]-1)):
 #     returns[t] = rewards[t] + args.gamma * returns[t+1] * (1-dones[t])
 # # advantages are returns - baseline, value estimates in our case
 # advantages = returns - values.detach().cpu().numpy()
 
 # vf_loss = loss_fn(torch.Tensor(returns).to(device), values) * args.vf_coef
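The return computation above appends a bootstrap value estimate, sweeps backwards through the rewards, and zeroes the bootstrap across episode boundaries via (1 - done). A minimal sketch of the same computation with made-up numbers:

import numpy as np

rewards = np.array([1.0, 0.0, 1.0, 0.5])
dones = np.array([0.0, 0.0, 1.0, 0.0])    # third step terminates an episode
gamma = 0.99
bootstrap_value = 2.0                      # value estimate of the state after the last step

returns = np.append(np.zeros_like(rewards), bootstrap_value)
for t in reversed(range(rewards.shape[0])):
    returns[t] = rewards[t] + gamma * returns[t + 1] * (1 - dones[t])
returns = returns[:-1]                     # drop the bootstrap slot
print(returns)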
Example #31
    def _forward(self,  # type: ignore
                 context: List[str],
                 image: torch.Tensor,
                 caption: Dict[str, torch.LongTensor]):

        # We assume that the first token in target is the <s> token. We
        # shall use it to seed the decoder. Here decoder_target is simply
        # decoder_input shifted by one step, so each position predicts the next token.
        caption_ids = caption[self.index]
        target_ids = torch.zeros_like(caption_ids)
        target_ids[:, :-1] = caption_ids[:, 1:]

        # The final token is not used as input to the decoder, since otherwise
        # we'll be predicting the <pad> token.
        caption_ids = caption_ids[:, :-1]
        target_ids = target_ids[:, :-1]

        # Truncate very long captions to avoid OOM errors
        caption_ids = caption_ids[:, :self.max_caption_len]
        target_ids = target_ids[:, :self.max_caption_len]

        caption[self.index] = caption_ids

        # Embed the image
        X_image = self.resnet(image)
        # X_image.shape == [batch_size, 2048, 7, 7]

        X_image = X_image.permute(0, 2, 3, 1)
        # X_image.shape == [batch_size, 7, 7, 2048]

        # Flatten out the image
        B, H, W, C = X_image.shape
        P = H * W  # number of pixels
        X_image = X_image.view(B, P, C)
        # X_image.shape == [batch_size, 49, 2048]

        # article_ids.shape == [batch_size, seq_len]

        context = [c.lower() for c in context]
        context_docs = self.nlp.pipe(context)
        vs = []
        v_lens = []
        for doc in context_docs:
            v = [token.vector for token in doc if token.has_vector]
            v_lens.append(len(v))
            vs.append(np.array(v))
        max_len = max(v_lens)

        context_vector = X_image.new_full((B, max_len, 300), np.nan)
        for i, v in enumerate(vs):
            v_len = v.shape[0]
            v_tensor = torch.from_numpy(v).type_as(context_vector)
            context_vector[i, :v_len] = v_tensor

        article_padding_mask = torch.isnan(context_vector).any(dim=-1)
        # article_padding_mask.shape == [batch_size, seq_len]

        X_article = context_vector
        X_article[article_padding_mask] = 0

        # Create padding mask (1 corresponds to the padding index)
        image_padding_mask = X_image.new_zeros(B, P).bool()

        # The quirks of dynamic convolution implementation: The context
        # embedding has dimension [seq_len, batch_size], but the mask has
        # dimension [batch_size, seq_len].
        contexts = {
            'image': X_image.transpose(0, 1),
            'image_mask': image_padding_mask,
            'article': X_article.transpose(0, 1),
            'article_mask': article_padding_mask,
            'sections': None,
            'sections_mask': None,
        }

        return caption_ids, target_ids, contexts
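The context-embedding step above uses NaN as a padding sentinel: variable-length token-vector sequences are written into a tensor pre-filled with NaN, the padding mask is recovered with torch.isnan, and the NaNs are then zeroed out before the tensor is used. A minimal sketch of the same idea with made-up shapes:

import torch

# Three "documents" of different lengths, each token a 4-dim vector.
seqs = [torch.randn(5, 4), torch.randn(2, 4), torch.randn(3, 4)]
max_len, dim = max(s.shape[0] for s in seqs), seqs[0].shape[1]

# Pre-fill with NaN so positions that are never written are unambiguously padding.
batch = torch.full((len(seqs), max_len, dim), float('nan'))
for i, s in enumerate(seqs):
    batch[i, :s.shape[0]] = s

# Padding mask: True wherever a position was left untouched.
padding_mask = torch.isnan(batch).any(dim=-1)   # shape [batch, max_len]

# Zero out the padded positions before feeding the tensor to a model.
batch[padding_mask] = 0
assert not torch.isnan(batch).any()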
Example #32
    def attack_PGD(self,
                   label,
                   out_dir=None,
                   save_title=None,
                   steps=5,
                   vertex_epsilon=1.0,
                   pose_epsilon=1.0,
                   lighting_epsilon=8000.0,
                   vertex_lr=0.001,
                   pose_lr=0.05,
                   lighting_lr=4000.0,
                   vertex_attack=True,
                   pose_attack=True,
                   lighting_attack=False):

        if out_dir is not None and save_title is None:
            raise Exception("Must provide image title if out dir is provided")
        elif save_title is not None and out_dir is None:
            raise Exception("Must provide directory if image is to be saved")

        filename = save_title

        # classify
        img = self.render_image(out_dir=out_dir, filename=filename)

        # only there to zero out gradients.
        optimizer = torch.optim.Adam(
            [self.translation, self.euler_angles, self.light.intensity], lr=0)
        angle_perturbations = torch.tensor([0., 0.,
                                            0.]).to(pyredner.get_device())
        vertex_perturbations_lst = []
        for shape in self.shapes:
            perturbation = torch.zeros(shape.vertices.shape).to(
                pyredner.get_device())
            vertex_perturbations_lst += [perturbation]

        for i in range(steps):
            optimizer.zero_grad()
            pred, net_out = self.classify(img)
            if pred.item() != label and i != 0:
                print("misclassification at step ", i)
                final_image = np.clip(
                    img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
                return pred, final_image
            # get gradients
            self._get_gradients(img.cpu(), net_out, label)
            delta = 1e-6
            inf_count = 0
            nan_count = 0

            if vertex_attack:
                # attack each shape's vertices
                for shape_idx in range(len(self.shapes)):
                    shape = self.shapes[shape_idx]
                    vertex_perturbations = vertex_perturbations_lst[shape_idx]
                    if not torch.isfinite(shape.vertices.grad).all():
                        inf_count += 1
                    elif torch.isnan(shape.vertices.grad).any():
                        nan_count += 1
                    else:
                        # initial perturbation size
                        p = shape.vertices.grad / (torch.norm(
                            shape.vertices.grad) + delta) * vertex_lr
                        # ensure the perturbation doesn't exceed the ball of radius epsilon -- if it does, clip it.
                        p = torch.min(
                            torch.max(p,
                                      vertex_perturbations - vertex_epsilon),
                            vertex_perturbations + vertex_epsilon)
                        # subtract because we are trying to decrease the classification score of the label
                        shape.vertices -= p
                        vertex_perturbations -= p

            if lighting_attack:
                light_sub = self.light.intensity.grad / (torch.norm(
                    self.light.intensity.grad) + delta) * lighting_lr
                light_sub = torch.min(
                    self.light.intensity.data,
                    light_sub)  # ensure lighting never goes negative
                self.light.intensity.data = torch.min(
                    torch.max(self.light.intensity.data - light_sub,
                              self.light_init_vals - lighting_epsilon),
                    self.light_init_vals + lighting_epsilon)
                print(self.light.intensity.data)

            if pose_attack:
                # initial perturbation size
                p = self.euler_angles.grad / (
                    torch.norm(self.euler_angles.grad) + delta) * pose_lr
                # ensure the perturbation doesn't exceed the ball of radius epsilon -- if it does, clip it.
                p = torch.min(torch.max(p, angle_perturbations - pose_epsilon),
                              angle_perturbations + pose_epsilon)
                # subtract because we are trying to decrease the classification score of the label
                self.euler_angles.data -= p
                angle_perturbations -= p

            img = self.render_image(out_dir=out_dir, filename=filename)

        final_pred, net_out = self.classify(img)
        final_image = np.clip(img[0].permute(1, 2, 0).data.cpu().numpy(), 0, 1)
        return final_pred, final_image
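Both of the perturbation updates above keep the accumulated perturbation inside an epsilon ball by clipping each normalized-gradient step against the running total. A stripped-down sketch of that projection for a single tensor; pgd_step and its arguments are placeholders, not names from the code above.

import torch

def pgd_step(x, grad, accumulated, lr=0.05, eps=1.0, delta=1e-6):
    """One normalized-gradient descent step that keeps the total perturbation
    (tracked in `accumulated`) inside an L-infinity ball of radius eps."""
    step = grad / (torch.norm(grad) + delta) * lr
    # Clip the step so that (accumulated - step) stays within [-eps, eps].
    step = torch.min(torch.max(step, accumulated - eps), accumulated + eps)
    return x - step, accumulated - step

x, accumulated = torch.randn(3), torch.zeros(3)
for _ in range(100):
    x, accumulated = pgd_step(x, torch.randn(3), accumulated)  # placeholder gradients
assert accumulated.abs().max() <= 1.0 + 1e-5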
Example #33
    def loss(self, H1, H2):
        """

        This is the loss function of CCA as introduced in the original paper; other formulations are possible.

        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        r1 = self.r1
        r2 = self.r2
        eps = 1e-9

        H1, H2 = H1.t(), H2.t()
        o1 = o2 = H1.size(0)
        #print(torch.isnan(H1).sum())
        #print(torch.isnan(H2).sum())

        assert torch.isnan(H1).sum().item() == 0
        assert torch.isnan(H2).sum().item() == 0

        m = H1.size(1)
        #print(H1.size())

        H1bar = H1 - H1.mean(dim=1).unsqueeze(dim=1)
        H2bar = H2 - H2.mean(dim=1).unsqueeze(dim=1)
        #H1Norm = H1bar/torch.norm(H1bar, dim=0)
        #H2Norm = H2bar/torch.norm(H2bar, dim=0)

        #print(torch.matrix_rank(H1Norm))
        #print(torch.matrix_rank(H2Norm))

        SigmaHat12 = (1.0 / (m - 1)) * torch.matmul(H1bar, H2bar.t())
        SigmaHat11 = (1.0 / (m - 1)) * torch.matmul(
            H1bar, H1bar.t()) + r1 * torch.eye(o1, device=self.device)
        SigmaHat22 = (1.0 / (m - 1)) * torch.matmul(
            H2bar, H2bar.t()) + r2 * torch.eye(o2, device=self.device)

        #print(SigmaHat11)
        #print(SigmaHat12)
        #print(torch.matrix_rank(SigmaHat12))
        #print(torch.matrix_rank(SigmaHat11))
        #print(torch.matrix_rank(SigmaHat22))

        #print(torch.isnan(SigmaHat11).sum())
        #print(torch.isnan(SigmaHat12).sum())
        #print(torch.isnan(SigmaHat22).sum())

        #assert torch.isnan(SigmaHat11).sum().item() == 0
        #assert torch.isnan(SigmaHat12).sum().item() == 0
        #assert torch.isnan(SigmaHat22).sum().item() == 0

        # Calculating the root inverse of covariance matrices by using eigen decomposition
        [D1, V1] = torch.symeig(SigmaHat11, eigenvectors=True)
        [D2, V2] = torch.symeig(SigmaHat22, eigenvectors=True)

        #print('D1 is :', D1)
        #print('D2 is :', D2)
        #print(torch.isnan(D1).sum())
        #print(torch.isnan(D2).sum())
        #print(torch.isnan(V1).sum())
        #print(torch.isnan(V2).sum())

        # Added to increase stability
        posInd1 = torch.gt(D1, eps).nonzero()[:, 0]
        D1 = D1[posInd1]
        V1 = V1[:, posInd1]
        posInd2 = torch.gt(D2, eps).nonzero()[:, 0]
        D2 = D2[posInd2]
        V2 = V2[:, posInd2]
        #print(posInd1.size())
        #print(posInd2.size())
        #print(torch.isnan(posInd1).sum())
        #print(torch.isnan(posInd2).sum())
        #print(torch.isnan(D1).sum())
        #print(torch.isnan(D2).sum())
        #print(torch.isnan(V1).sum())
        #print(torch.isnan(V2).sum())

        SigmaHat11RootInv = torch.matmul(
            torch.matmul(V1, torch.diag(D1**-0.5)), V1.t())
        SigmaHat22RootInv = torch.matmul(
            torch.matmul(V2, torch.diag(D2**-0.5)), V2.t())

        Tval = torch.matmul(torch.matmul(SigmaHat11RootInv, SigmaHat12),
                            SigmaHat22RootInv)
        #         print(Tval.size())
        #print(torch.isnan(SigmaHat11RootInv).sum())
        #print(torch.isnan(SigmaHat22RootInv).sum())
        #print(torch.isnan(Tval).sum())

        if self.use_all_singular_values:
            # all singular values are used to calculate the correlation
            tmp = torch.trace(torch.matmul(Tval.t(), Tval))
            # print(tmp)
            corr = torch.sqrt(tmp)
            # assert torch.isnan(corr).item() == 0
        else:
            # just the top self.outdim_size singular values are used
            sym = torch.matmul(Tval.t(),
                               Tval) + r1 * torch.eye(o1, device=self.device)
            #print(torch.matrix_rank(sym))
            #print(torch.isnan(sym).sum())
            U, V = torch.symeig(sym, eigenvectors=True)
            #U = U[torch.gt(U, eps).nonzero()[:, 0]]
            U = U.topk(self.outdim_size)[0]
            #print('U is: ' ,U)
            corr = torch.sum(torch.sqrt(U))
            #print(torch.isnan(U).sum())
            #print(torch.isnan(V).sum())

        #print(corr)
        return -corr
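The stability trick in this loss, eigendecompose the covariance, keep only eigenvalues above eps, and rebuild the root inverse, is worth isolating. A sketch using torch.linalg.eigh, since torch.symeig (used above) is deprecated in recent PyTorch releases; root_inverse is a hypothetical helper and the shapes and regularization are illustrative.

import torch

def root_inverse(sigma, eps=1e-9):
    """Numerically stable Sigma^(-1/2) for a symmetric PSD matrix: eigenvalues
    at or below eps are dropped, mirroring the filtering in the loss above."""
    d, v = torch.linalg.eigh(sigma)          # eigenvalues in ascending order
    keep = d > eps
    d, v = d[keep], v[:, keep]
    return v @ torch.diag(d.pow(-0.5)) @ v.t()

# Usage sketch: regularized covariance of a (features x samples) matrix H.
H = torch.randn(8, 100)
Hbar = H - H.mean(dim=1, keepdim=True)
sigma = Hbar @ Hbar.t() / (H.shape[1] - 1) + 1e-3 * torch.eye(8)
assert not torch.isnan(root_inverse(sigma)).any()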
Example #34
    def run_epoch(_loader, _model, _optimizer, _tag, _ema, _epoch, _scheduler=None, max_step=100000):
        params_without_bn = [params for name, params in _model.named_parameters() if not ('_bn' in name or '.bn' in name)]

        tta_cnt = [0] * tta_num
        metric = Accumulator()
        batch = []
        total_steps = len(_loader)
        tqdm_loader = tqdm(_loader, desc=f'[{_tag} epoch={_epoch+1:03}/{args.epoch:03}]', total=min(max_step, total_steps))
        try:
            for example_id, (img_orig, lb, losses, corrects) in enumerate(tqdm_loader):
                batch.append((img_orig, lb, losses, corrects))
                if (example_id + 1) % args.batch != 0:
                    continue

                if max_step < example_id:
                    break

                imgs = torch.cat([x[0] for x in batch]).cuda()
                lbs = torch.cat([x[1] for x in batch]).long().cuda()
                losses = torch.cat([x[2] for x in batch]).cuda()
                corrects = torch.cat([x[3] for x in batch]).cuda()
                assert len(imgs) > 0

                imgs = imgs.view(imgs.size(0) * imgs.size(1), imgs.size(2), imgs.size(3), imgs.size(4))
                lbs = lbs.view(lbs.size(0) * lbs.size(1))
                losses = losses.view(losses.size(0) * losses.size(1), -1)
                corrects = corrects.view(corrects.size(0) * corrects.size(1), -1)
                assert losses.shape[1] == tta_num, losses.shape
                assert corrects.shape[1] == tta_num, corrects.shape
                assert torch.isnan(losses).sum() == 0

                softmin_target = torch.nn.functional.softmin(losses / args.tau, dim=1).detach()
                pred = _model(imgs)
                pred_softmax = torch.nn.functional.softmax(pred, dim=1)
                assert torch.isnan(pred).sum() == 0, pred
                assert torch.isnan(pred_softmax).sum() == 0, pred_softmax
                assert torch.isnan(softmin_target).sum() == 0
                assert softmin_target.shape[0] == pred_softmax.shape[0], (softmin_target.shape, pred_softmax.shape)
                assert softmin_target.shape[1] == pred_softmax.shape[1], (softmin_target.shape, pred_softmax.shape)

                pred_final = pred_softmax
                loss = spearman_loss(pred_softmax, softmin_target)

                if _optimizer is not None:
                    loss_total = loss + args.decay * sum([torch.norm(p, p=args.regularization) for p in params_without_bn])
                    loss_total.backward()
                    _optimizer.step()
                    _optimizer.zero_grad()

                if _ema is not None:
                    _ema(_model, _epoch * total_steps + example_id)

                for idx in torch.argmax(pred_softmax, dim=1):
                    tta_cnt[idx] += 1

                pred_correct = torch.Tensor([x[y] for x, y in zip(corrects, torch.argmax(pred_final, dim=1))])
                orac_correct = torch.Tensor([x[y] for x, y in zip(corrects, torch.argmax(softmin_target, dim=1))])
                defa_correct = corrects[:, encoded_tta_default()]

                pred_loss = torch.Tensor([x[y] for x, y in zip(losses, torch.argmax(pred_final, dim=1))])
                defa_loss = losses[:, encoded_tta_default()]
                corr_p = prediction_correlation(pred_final, softmin_target)

                metric.add('loss', loss.item())
                metric.add('l_l2t', torch.mean(pred_loss).item())
                metric.add('l_org', torch.mean(defa_loss).item())
                metric.add('top1_l2t', torch.mean(pred_correct).item())
                metric.add('top1_oracle', torch.mean(orac_correct).item())
                metric.add('top1_org', torch.mean(defa_correct).item())
                metric.add('corr_p', corr_p)
                metric.add('cnt', 1)
                tqdm_loader.set_postfix(
                    lr=_optimizer.param_groups[0]['lr'] if _optimizer is not None else 0,
                    l=metric['loss'] / metric['cnt'],
                    l_l2t=metric['l_l2t'] / metric['cnt'],
                    l_org=metric['l_org'] / metric['cnt'],
                    # l_curr=loss.item(),
                    corr_p=metric['corr_p'] / metric['cnt'],
                    acc_l2t=metric['top1_l2t'] / metric['cnt'],
                    acc_org=metric['top1_org'] / metric['cnt'],
                    acc_d=(metric['top1_l2t'] - metric['top1_org']) / metric['cnt'],
                    acc_O=metric['top1_oracle'] / metric['cnt'],
                    # tta_top=decode_desc(np.argmax(tta_cnt)),
                    # tta_max='%.2f(%d)' % (max(tta_cnt) / float(sum(tta_cnt)), np.argmax(tta_cnt)),
                    ttas=f'{tta_cnt[0]/sum(tta_cnt):.2f},{tta_cnt[-3]/sum(tta_cnt):.2f},{tta_cnt[-2]/sum(tta_cnt):.2f},{tta_cnt[-1]/sum(tta_cnt):.2f}'
                    # tta_min='%.2f' % (min(tta_cnt) / float(sum(tta_cnt))),
                    # grad_l2=metric['grad_l2'] / metric['cnt'],
                )

                batch = []
                if _scheduler is not None:
                    _scheduler.step(_epoch + (float(example_id) / total_steps))
                del pred, loss
        except KeyboardInterrupt as e:
            if 'test' not in _tag:
                raise e
            pass
        finally:
            tqdm_loader.close()

        del tqdm_loader, batch

        if 'test' in _tag:
            if metric['top1_l2t'] >= metric['top1_org']:
                c = 107     # green
            else:
                c = 124     # red

        else:
            if metric['top1_l2t'] >= metric['top1_org']:
                c = 149
            else:
                c = 14      # light_cyan
        logger.info(f'[{_tag} epoch={_epoch + 1}] ' + stylize(
            'loss=%.4f l(l2t=%.4f org=%.4f) top1_O=%.4f top1_l2t=%.4f top1_org=%.4f << corr_p=%.4f delta=%.4f %s(%s)>>' %
            (metric['loss'] / metric['cnt'],
             metric['l_l2t'] / metric['cnt'], metric['l_org'] / metric['cnt'],
             metric['top1_oracle'] / metric['cnt'],
             metric['top1_l2t'] / metric['cnt'],
             metric['top1_org'] / metric['cnt'],
             metric['corr_p'] / metric['cnt'],
             (metric['top1_l2t'] / metric['cnt']) - (metric['top1_org'] / metric['cnt']),
             decode_desc(np.argmax(tta_cnt)), '%.2f(%d)' % (max(tta_cnt) / float(sum(tta_cnt)), np.argmax(tta_cnt)),
             )
        , colored.fg(c)))
        return metric
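The target construction above turns the per-augmentation losses into a soft distribution via a temperature-scaled softmin, so lower loss maps to higher target probability. A tiny sketch of that transformation with a placeholder temperature tau:

import torch
import torch.nn.functional as F

losses = torch.tensor([[0.2, 1.5, 0.9],
                       [2.0, 0.1, 0.4]])           # per-sample, per-TTA losses
tau = 0.5
softmin_target = F.softmin(losses / tau, dim=1)    # lower loss -> higher probability
assert torch.isnan(softmin_target).sum() == 0
assert torch.allclose(softmin_target.sum(dim=1), torch.ones(2))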
Example #35
def embedding_dist(x1, x2, pos_metric, tau=0.05, xent=False):

    if xent:
        #X1 denotes the batch of anchors while X2 denotes all the negative matches
        #Broadcasting to compute loss for each anchor over all the negative matches

        #Only implemented if x1, x2 are rank-2 tensors
        if len(x1.shape) != 2 or len(x2.shape) != 2:
            print(
                'Error: both should be rank 2 tensors for NT-Xent loss computation'
            )

        #Normalizing each vector
        ## Take care to reshape the norm: for an (N, D) tensor the norm has shape (N) and must be reshaped to (N, 1) so that row-wise L2 normalization broadcasts correctly
        if torch.sum(torch.isnan(x1)):
            print('X1 is nan')
            sys.exit()

        if torch.sum(torch.isnan(x2)):
            print('X2 is nan')
            sys.exit()

        eps = 1e-8

        norm = x1.norm(dim=1)
        norm = norm.view(norm.shape[0], 1)
        temp = eps * torch.ones_like(norm)

        x1 = x1 / torch.max(norm, temp)

        if torch.sum(torch.isnan(x1)):
            print('X1 Norm is nan')
            sys.exit()

        norm = x2.norm(dim=1)
        norm = norm.view(norm.shape[0], 1)
        temp = eps * torch.ones_like(norm)

        x2 = x2 / torch.max(norm, temp)

        if torch.sum(torch.isnan(x2)):
            print('Norm: ', norm, x2)
            print('X2 Norm is nan')
            sys.exit()

        # Broadcasting the anchor vectors to compute the loss over all negative matches
        x1 = x1.unsqueeze(1)
        cos_sim = torch.sum(x1 * x2, dim=2)
        cos_sim = cos_sim / tau

        if torch.sum(torch.isnan(cos_sim)):
            print('Cos is nan')
            sys.exit()

        loss = torch.sum(torch.exp(cos_sim), dim=1)

        if torch.sum(torch.isnan(loss)):
            print('Loss is nan')
            sys.exit()

        return loss

    else:
        if pos_metric == 'l1':
            return l1_dist(x1, x2)
        elif pos_metric == 'l2':
            return l2_dist(x1, x2)
        elif pos_metric == 'cos':
            return cosine_similarity(x1, x2)
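The row-wise normalization in the NT-Xent branch guards against zero-norm rows by flooring the norm at eps before dividing; dividing by an exact zero norm is precisely what produces the NaNs the function keeps checking for. A compact sketch of that guard, with l2_normalize_rows as a hypothetical helper name:

import torch

def l2_normalize_rows(x, eps=1e-8):
    # Floor the per-row norm at eps so all-zero rows map to zero instead of NaN.
    norm = x.norm(dim=1, keepdim=True)
    return x / torch.clamp(norm, min=eps)

x = torch.cat([torch.randn(3, 4), torch.zeros(1, 4)])   # last row is all zeros
naive = x / x.norm(dim=1, keepdim=True)                 # NaNs appear in the last row
safe = l2_normalize_rows(x)
assert torch.isnan(naive).any() and not torch.isnan(safe).any()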
Example #36
    def forward(self) -> Tuple[torch.Tensor, int]:
        unstable = 0

        # get sigma/mean or each level
        sigma_w, mu_w = self.implied_sigma_mu()
        sigma_l2, mu_l2 = self.implied_sigma_mu(suffix="_l2")

        # decompose into 11, 12, 21, 22
        sigma_b, sigma_xx, sigma_yx, mu_b, mu_x = self._split(sigma_l2, mu_l2)

        # cluster FIML -2 * logL (without constants)
        loss = torch.zeros(1, dtype=mu_w.dtype, device=mu_w.device)

        # go through each cluster separately
        data_ys_available = ~torch.isnan(self.data_ys)
        cache_S_ij = {}
        cache_S_j_R_j = {}
        sigma_b_logdet = None
        sigma_b_inv = None
        if not self.naive_implementation:
            sigma_b_logdet = torch.logdet(sigma_b)
            sigma_b_inv = torch.inverse(sigma_b)
        for cluster_slice, batches in self.missing_patterns:
            # get cluster data and define R_j for current cluster j
            cluster_x = self.data_xs[cluster_slice.start, :]
            R_j_index = ~torch.isnan(cluster_x)
            no_cluster = ~R_j_index.any()

            # cache
            key = (
                tuple(R_j_index.tolist()),
                tuple([
                    tuple(x)
                    for x in data_ys_available[cluster_slice, :].tolist()
                ]),
            )
            sigma_j_logdet, sigma_j_inv = cache_S_j_R_j.get(key, (None, None))

            # define S_ij and S_j
            S_ijs = []
            eye_w = torch.eye(mu_w.shape[0],
                              dtype=mu_w.dtype,
                              device=mu_w.device)
            lambda_ijs_logdet_sum = 0.0
            lambda_ijs_inv = []
            A_j = torch.zeros_like(sigma_w)
            for batch_slice in batches:
                size = batch_slice.stop - batch_slice.start
                available = data_ys_available[batch_slice.start]
                S_ij = eye_w[available, :]
                S_ijs.extend([S_ij] * size)

                if self.naive_implementation or sigma_j_logdet is not None:
                    continue

                key_S_ij = tuple(available.tolist())
                lambda_ij_inv, lambda_ij_logdet, a_j = cache_S_ij.get(
                    key_S_ij, (None, None, None))

                if lambda_ij_inv is None:
                    lambda_ij = sigma_w  # no missing data
                    if S_ij.shape[0] != eye_w.shape[0]:
                        # missing data
                        lambda_ij = S_ij.mm(sigma_w.mm(S_ij.t()))
                    lambda_ij_inv = torch.inverse(lambda_ij)
                    lambda_ij_logdet = torch.logdet(lambda_ij)

                    if S_ij.shape[0] != eye_w.shape[0]:
                        # missing data
                        a_j = S_ij.t().mm(lambda_ij_inv.mm(S_ij))
                    else:
                        a_j = lambda_ij_inv
                    cache_S_ij[key_S_ij] = lambda_ij_inv, lambda_ij_logdet, a_j

                lambda_ijs_inv.extend([lambda_ij_inv] * size)
                lambda_ijs_logdet_sum = lambda_ijs_logdet_sum + lambda_ij_logdet * size
                A_j = A_j + a_j * size

            S_j = torch.cat(S_ijs, dim=0)

            # means
            y_j = torch.cat([
                self.data_ys[cluster_slice, :][data_ys_available[
                    cluster_slice, :]][:, None],
                cluster_x[R_j_index, None],
            ])
            mu_y = mu_w + mu_b
            mu_j = torch.cat([S_j.mm(mu_y), mu_x[R_j_index]])
            mean_diff = y_j - mu_j
            G_yj = mean_diff.mm(mean_diff.t())

            if sigma_j_logdet is None and not self.naive_implementation:
                sigma_b_inv_A_j = sigma_b_inv + A_j
                B_j = torch.inverse(sigma_b_inv_A_j)
                C_j = eye_w - A_j.mm(B_j)
                D_j = C_j.mm(A_j)
                lambda_inv = block_diag(lambda_ijs_inv)
                V_j_inv = lambda_inv - lambda_inv.mm(
                    S_j.mm(B_j.mm(S_j.t().mm(lambda_inv))))

                if no_cluster:
                    # no cluster
                    sigma_11_j = V_j_inv
                    sigma_21_j = torch.empty(0,
                                             device=sigma_11_j.device,
                                             dtype=sigma_11_j.dtype)
                    sigma_22_1 = torch.empty([0, 0],
                                             device=sigma_11_j.device,
                                             dtype=sigma_11_j.dtype)
                    sigma_22_inv = sigma_21_j

                else:
                    # normal case
                    sigma_22_1 = (sigma_xx - sigma_yx.t().mm(
                        D_j.mm(sigma_yx)))[R_j_index, :][:, R_j_index]
                    sigma_22_inv = torch.inverse(sigma_22_1)
                    sigma_jyx = S_j.mm(sigma_yx[:, R_j_index])
                    sigma_11_j = (V_j_inv.mm(
                        sigma_jyx.mm(sigma_22_inv.mm(
                            sigma_jyx.t().mm(V_j_inv)))) + V_j_inv)
                    sigma_21_j = -sigma_22_inv.mm(sigma_jyx.t().mm(V_j_inv))

                sigma_j_inv = torch.cat(
                    [
                        torch.cat([sigma_11_j, sigma_21_j]),
                        torch.cat([sigma_21_j.t(), sigma_22_inv]),
                    ],
                    dim=1,
                )

                sigma_j_logdet = (lambda_ijs_logdet_sum + sigma_b_logdet +
                                  torch.logdet(sigma_b_inv_A_j) +
                                  torch.logdet(sigma_22_1))
                cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)

            elif sigma_j_logdet is None:
                # naive
                sigma_j = S_j.mm(sigma_b.mm(S_j.t())) + block_diag(
                    [S_ij.mm(sigma_w.mm(S_ij.t())) for S_ij in S_ijs])
                if not no_cluster:
                    sigma_j_12 = S_j.mm(sigma_yx[:, R_j_index])
                    sigma_j_21 = sigma_j_12.t()
                    sigma_j_22 = sigma_xx[R_j_index, :][:, R_j_index]
                    sigma_j = torch.cat(
                        [
                            torch.cat([sigma_j, sigma_j_21]),
                            torch.cat([sigma_j_12, sigma_j_22]),
                        ],
                        dim=1,
                    )
                sigma_j_logdet = torch.logdet(sigma_j)
                sigma_j_inv = torch.inverse(sigma_j)
                cache_S_j_R_j[key] = (sigma_j_logdet, sigma_j_inv)

            loss_current = sigma_j_logdet + torch.trace(sigma_j_inv.mm(G_yj))
            unstable += loss_current.detach().item() < 0
            loss = loss + loss_current.clamp(min=0.0)

        return loss, unstable
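The missing-data handling in this forward pass is driven entirely by torch.isnan: each row's availability mask selects which entries of the mean and covariance enter that row's likelihood term, and rows sharing the same pattern reuse cached matrices. A small sketch of grouping rows by missingness pattern, with illustrative data:

import torch
from collections import defaultdict

data = torch.tensor([[1.0, float('nan'), 3.0],
                     [2.0, float('nan'), 1.0],
                     [float('nan'), 5.0, 0.0]])

groups = defaultdict(list)
for i, row in enumerate(data):
    available = ~torch.isnan(row)               # per-row availability mask
    groups[tuple(available.tolist())].append(i)

# Rows 0 and 1 share the pattern (True, False, True), so any per-pattern quantity
# (for example a sub-covariance and its inverse) only needs to be computed once.
for pattern, rows in groups.items():
    print(pattern, rows)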
Example #37
 def test_isnan(self):
     x = torch.tensor([1, float('nan'), 2])
     self.assertONNX(lambda x: torch.isnan(x), x)