Example #1
 def inference_step(self, adj_matrix, feat_matrix, labels, indices):
     """
     Run a forward pass on the adjacency and feature matrices and compute
     loss and accuracy on the rows selected by `indices`.
     """
     self.model.eval()
     out = self.model(adj_matrix, feat_matrix)
     loss = self.loss(out[indices], labels[indices]).detach()
     acc = utils.calculate_accuracy(out[indices], labels[indices])
     return loss, acc
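All of these snippets call a calculate_accuracy helper defined elsewhere in
their repositories. As a rough sketch, such a helper usually just compares the
argmax of the logits with the labels; the torch-based version below is an
assumption, not code from the source (Examples #3, #4, #6 and #7 pass numpy
arrays instead, so their variant presumably does the same with np.argmax):

import torch

def calculate_accuracy(outputs, labels):
    """Fraction of rows where the argmax of the logits matches the label."""
    preds = outputs.argmax(dim=1)          # predicted class per row
    return (preds == labels).sum().item() / labels.size(0)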
Example #2
 def train_step(self, adj_matrix, feat_matrix, labels, train_indices):
     """Run one optimization step on the training indices and return
     the training accuracy and loss."""
     self.model.train()
     self.optimizer.zero_grad()
     out = self.model(adj_matrix, feat_matrix)
     train_loss = self.loss(out[train_indices], labels[train_indices])
     train_acc = utils.calculate_accuracy(out[train_indices],
                                          labels[train_indices])
     train_loss.backward()
     self.optimizer.step()
     return train_acc, train_loss
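A hypothetical driver loop showing how these two Trainer methods compose;
trainer, the data tensors, and num_epochs are assumed names, not from the
source:

for epoch in range(num_epochs):
    train_acc, train_loss = trainer.train_step(adj_matrix, feat_matrix,
                                               labels, train_indices)
    val_loss, val_acc = trainer.inference_step(adj_matrix, feat_matrix,
                                               labels, val_indices)
    print(f"epoch {epoch}: train_loss={train_loss.item():.4f} "
          f"val_acc={val_acc:.3f}")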
Example #3
 def calc_local_to_local(self, local_t, local_tp1):
     N, sy, sx, d = local_tp1.shape
     # Loss 2: f5 patches at time t, with f5 patches at time t+1
     local_t_score = self.score_fxn2(local_t.reshape(-1, d))
     transformed_local_t = local_t_score.reshape(N, sy * sx,
                                                 d).transpose(0, 1)
     local_tp1 = local_tp1.reshape(N, sy * sx, d).transpose(0, 1)
     logits2 = torch.matmul(transformed_local_t,
                            local_tp1.transpose(1, 2)).reshape(-1, N)
     target2 = torch.arange(N).repeat(sx * sy).to(self.device)
     loss2 = nn.CrossEntropyLoss()(logits2, target2)
     acc2 = calculate_accuracy(logits2.detach().cpu().numpy(),
                               target2.detach().cpu().numpy())
     return loss2, acc2
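To make the reshape/transpose bookkeeping above concrete, here is a small
shape check with dummy tensors and a linear layer standing in for score_fxn2
(all sizes hypothetical):

import torch

N, sy, sx, d = 4, 2, 3, 8                                 # hypothetical sizes
local_t = torch.randn(N, sy, sx, d)
local_tp1 = torch.randn(N, sy, sx, d)
score_fxn2 = torch.nn.Linear(d, d)                        # stand-in scorer

scored = score_fxn2(local_t.reshape(-1, d))
t = scored.reshape(N, sy * sx, d).transpose(0, 1)         # (sy*sx, N, d)
tp1 = local_tp1.reshape(N, sy * sx, d).transpose(0, 1)    # (sy*sx, N, d)
logits = torch.matmul(t, tp1.transpose(1, 2)).reshape(-1, N)
target = torch.arange(N).repeat(sx * sy)

assert logits.shape == (sy * sx * N, N)                   # one row per pair
assert target.shape == (sy * sx * N,)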
Example #4
 def calc_global_to_local(self, global_t, local_tp1):
     N, sy, sx, d = local_tp1.shape
     # Loss 1: Global at time t, f5 patches at time t+1
     glob_score = self.score_fxn1(global_t)
     local_flattened = local_tp1.reshape(-1, d)
     # [N*sy*sx, d] @ [d, N] -> [N*sy*sx, N]: dot product of every global
     # vector in the batch with the local voxels at all spatial locations
     # for all examples in the batch, then reshape to (sy*sx, N, N) and
     # flatten to (sy*sx*N, N).
     logits1 = torch.matmul(local_flattened, glob_score.t()).reshape(
         N, sy * sx, -1).transpose(1, 0).reshape(-1, N)
     # We now have sy*sx matrices of shape N x N whose diagonals are the dot
     # products of pairs that are consecutive in time at the same batch index,
     # i.e. the correct answers, so the correct logit index is the diagonal,
     # repeated sx*sy times.
     target1 = torch.arange(N).repeat(sx * sy).to(self.device)
     loss1 = nn.CrossEntropyLoss()(logits1, target1)
     acc1 = calculate_accuracy(logits1.detach().cpu().numpy(),
                               target1.detach().cpu().numpy())
     return loss1, acc1
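The target construction is worth a second look: each of the sy*sx score
matrices has its positives on the diagonal, so the correct class index is
just 0..N-1 repeated sx*sy times. A tiny illustration with hypothetical
sizes:

import torch

N, sy, sx = 3, 1, 2
target = torch.arange(N).repeat(sx * sy)
print(target)   # tensor([0, 1, 2, 0, 1, 2])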
Example #5
def val_epoch(epoch, data_loader, model, criterion, opt, logger):
    print('\t************** VALIDATION **************')
    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        if not opt.no_cuda:
            # `async=True` is a syntax error on Python 3.7+; the argument
            # was renamed to non_blocking.
            targets = targets.cuda(non_blocking=True)
        # Variable(..., volatile=True) is deprecated; torch.no_grad()
        # disables autograd for the validation forward pass instead.
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        print('\tBatch: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(i + 1,
                                                         len(data_loader),
                                                         batch_time=batch_time,
                                                         data_time=data_time,
                                                         loss=losses,
                                                         acc=accuracies))

    logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    return accuracies.avg
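Both epoch loops lean on an AverageMeter helper that is not shown here. A
typical minimal implementation, consistent with the .update(val, n), .val and
.avg usage above (an assumption about the repo's helper, not copied from it):

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count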
Example #6
    def calc_loss2(self, slots_t, slots_pos):
        """Loss 2:  Does a pair of vectors close in time come from the
                          same slot of different slots"""


        batch_size, num_slots, slot_len = slots_t.shape
        # logits: batch_size x num_slots x num_slots
        #         for each example (a set of num_slots slots), the dot product
        #         of each slot with every slot at the next time step
        logits = torch.matmul(self.score_matrix_2(slots_t),
                              slots_pos.transpose(2, 1))
        inp = logits.reshape(batch_size * num_slots, -1)
        target = torch.arange(num_slots).repeat(batch_size).to(self.device)
        loss2 = nn.CrossEntropyLoss()(inp, target)
        acc2 = calculate_accuracy(inp.detach().cpu().numpy(),
                                  target.detach().cpu().numpy())
        if self.training:
            self.wandb.log({"tr_acc2": acc2})
            self.wandb.log({"tr_loss2": loss2.item()})
        else:
            self.wandb.log({"val_acc2": acc2})
            self.wandb.log({"val_loss2": loss2.item()})
        return loss2
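A quick shape sketch of the slot-to-slot logits, with a linear layer standing
in for self.score_matrix_2 (whether it is really a linear map is an
assumption):

import torch

batch_size, num_slots, slot_len = 4, 8, 16    # hypothetical sizes
slots_t = torch.randn(batch_size, num_slots, slot_len)
slots_pos = torch.randn(batch_size, num_slots, slot_len)
score_matrix_2 = torch.nn.Linear(slot_len, slot_len, bias=False)

logits = torch.matmul(score_matrix_2(slots_t), slots_pos.transpose(2, 1))
assert logits.shape == (batch_size, num_slots, num_slots)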
Example #7
    def calc_loss1(self, slots_t, slots_pos):
        """Loss 1: Does a pair of slot vectors from the same slot
                   come from consecutive (or within a small window) time steps or not?"""
        batch_size, num_slots, slot_len = slots_t.shape

        # logits: num_slots x batch_size x batch_size
        #         for each slot, the dot product of each example in the batch
        #         with every other example in the batch
        logits = torch.matmul(self.score_matrix_1(slots_t).transpose(1, 0),
                              slots_pos.permute(1, 2, 0))

        inp = logits.reshape(num_slots * batch_size, -1)
        target = torch.arange(batch_size).repeat(num_slots).to(self.device)
        loss1 = nn.CrossEntropyLoss()(inp, target)
        acc1 = calculate_accuracy(inp.detach().cpu().numpy(), target.detach().cpu().numpy())

        if self.training:
            self.wandb.log({"tr_acc1": acc1})
            self.wandb.log({"tr_loss1": loss1.item()})
        else:
            self.wandb.log({"val_acc1": acc1})
            self.wandb.log({"val_loss1": loss1.item()})
        return loss1
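Note the complementary contrast dimensions: calc_loss2 scores slots against
slots within each example, while calc_loss1 scores examples against examples
within each slot. The transpose/permute above yields one batch_size x
batch_size score matrix per slot (linear stand-in again assumed):

import torch

batch_size, num_slots, slot_len = 4, 8, 16    # hypothetical sizes
slots_t = torch.randn(batch_size, num_slots, slot_len)
slots_pos = torch.randn(batch_size, num_slots, slot_len)
score_matrix_1 = torch.nn.Linear(slot_len, slot_len, bias=False)

logits = torch.matmul(score_matrix_1(slots_t).transpose(1, 0),  # (S, B, L)
                      slots_pos.permute(1, 2, 0))               # (S, L, B)
assert logits.shape == (num_slots, batch_size, batch_size)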
Example #8
def train_epoch(epoch, data_loader, model, criterion, optimizer, opt,
                epoch_logger, batch_logger):
    print('\t*************** TRAINING ***************')
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        if not opt.no_cuda:
            # `async=True` is a syntax error on Python 3.7+; the argument
            # was renamed to non_blocking.
            targets = targets.cuda(non_blocking=True)
        # Variable wrappers are deprecated no-ops since PyTorch 0.4;
        # tensors can be passed to the model directly.
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        batch_logger.log({
            'epoch': epoch,
            'batch': i + 1,
            'iter': (epoch - 1) * len(data_loader) + (i + 1),
            'loss': losses.val,
            'acc': accuracies.val,
            'lr': optimizer.param_groups[0]['lr']
        })

        print('\tBatch: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
              'lr {lr:.5f}'.format(i + 1,
                                   len(data_loader),
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses,
                                   acc=accuracies,
                                   lr=optimizer.param_groups[0]['lr']))

    epoch_logger.log({
        'epoch': epoch,
        'loss': losses.avg,
        'acc': accuracies.avg,
        'lr': optimizer.param_groups[0]['lr']
    })

    if epoch % opt.checkpoint == 0:
        save_file_path = os.path.join(opt.result_path,
                                      'save_{}.pth'.format(epoch))
        states = {
            'epoch': epoch + 1,
            'arch': opt.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(states, save_file_path)

    return losses.avg, accuracies.avg
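A minimal sketch of how the two epoch loops might be wired together;
opt.n_epochs, the loaders, and the logger objects are assumed names, not
taken from the source:

best_acc = 0.0
for epoch in range(1, opt.n_epochs + 1):
    train_epoch(epoch, train_loader, model, criterion, optimizer, opt,
                epoch_logger, batch_logger)
    val_acc = val_epoch(epoch, val_loader, model, criterion, opt, val_logger)
    best_acc = max(best_acc, val_acc)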