Example #1
            # log per-iteration metrics to TensorBoard and the console
            writer.add_scalar('data/loss', float(loss), global_count)
            writer.add_scalar('data/acc', float(acc), global_count)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # collapse the epoch's running averages to plain floats
        tl = tl.item()
        ta = ta.item()

        model.eval()

        vl = Averager()
        va = Averager()

        # episodic labels: each of the validation_way classes contributes
        # args.query query samples, so labels are 0..way-1 repeated
        label = torch.arange(args.validation_way).repeat(args.query)
        label = label.long()
        if torch.cuda.is_available():
            label = label.cuda()

        # report the best validation result seen so far
        print('best epoch {}, best val acc={:.4f}'.format(
            trlog['max_acc_epoch'], trlog['max_acc']))
        with torch.no_grad():
            for i, batch in enumerate(val_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]
                # the original snippet truncates here; the rest of the loop
                # body follows the usual few-shot validation pattern
                logits = model(data)
                loss = F.cross_entropy(logits, label)
                acc = (logits.argmax(dim=1) == label).float().mean().item()
                vl.add(loss.item())
                va.add(acc)
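
Both loops above accumulate metrics through an Averager helper that only
exposes add() and item(). The class itself is not part of the snippet; the
sketch below is a minimal running-mean implementation consistent with that
usage.

class Averager():

    def __init__(self):
        self.n = 0
        self.v = 0

    def add(self, x):
        # incremental running mean over all values added so far
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v
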
Example #2
import torch
from torch import optim
import torch.nn.functional as F

# ProtoNet and the checkpoint helpers are project modules not shown in this
# snippet; the import paths below are assumptions about the repo layout.
# from models.protonet import ProtoNet
# from utils.checkpoint_util import load_pretrain_checkpoint, load_model_checkpoint


class ProtoLearner(object):
    def __init__(self, args, mode='train'):

        # init model and optimizer
        self.model = ProtoNet(args)
        print(self.model)
        if torch.cuda.is_available():
            self.model.cuda()

        if mode == 'train':
            # the pretrained encoder gets a small fixed lr; the remaining
            # heads train at args.lr
            param_groups = [{
                'params': self.model.encoder.parameters(),
                'lr': 0.0001
            }, {
                'params': self.model.base_learner.parameters()
            }]
            if args.use_attention:
                param_groups.append(
                    {'params': self.model.att_learner.parameters()})
            else:
                param_groups.append(
                    {'params': self.model.linear_mapper.parameters()})
            self.optimizer = torch.optim.Adam(param_groups, lr=args.lr)
            # set the learning rate scheduler
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, step_size=args.step_size, gamma=args.gamma)
            # load pretrained model for point cloud encoding
            self.model = load_pretrain_checkpoint(
                self.model, args.pretrain_checkpoint_path)
        elif mode == 'test':
            # Load model checkpoint
            self.model = load_model_checkpoint(self.model,
                                               args.model_checkpoint_path,
                                               mode='test')
        else:
            raise ValueError('Wrong ProtoLearner mode (%s)! Options: train/test'
                             % mode)

    def train(self, data):
        """
        Args:
            data: a list of torch tensors with the following entries.
            - support_x: support point clouds with shape (n_way, k_shot, in_channels, num_points)
            - support_y: support masks (foreground) with shape (n_way, k_shot, num_points)
            - query_x: query point clouds with shape (n_queries, in_channels, num_points)
            - query_y: query labels with shape (n_queries, num_points)
        """

        [support_x, support_y, query_x, query_y] = data
        self.model.train()

        query_logits, loss = self.model(support_x, support_y, query_x, query_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # the scheduler advances once per call to train(), so args.step_size
        # is counted in training episodes
        self.lr_scheduler.step()

        # softmax is monotonic, so argmax over the raw logits gives the same
        # prediction; softmax is kept here for readability
        query_pred = F.softmax(query_logits, dim=1).argmax(dim=1)
        correct = torch.eq(query_pred,
                           query_y).sum().item()  # including background class
        accuracy = correct / (query_y.shape[0] * query_y.shape[1])

        return loss, accuracy

    def test(self, data):
        """
        Args:
            support_x: support point clouds with shape (n_way, k_shot, in_channels, num_points)
            support_y: support masks (foreground) with shape (n_way, k_shot, num_points), each point \in {0,1}.
            query_x: query point clouds with shape (n_queries, in_channels, num_points)
            query_y: query labels with shape (n_queries, num_points), each point \in {0,..., n_way}
        """
        [support_x, support_y, query_x, query_y] = data
        self.model.eval()

        with torch.no_grad():
            logits, loss = self.model(support_x, support_y, query_x, query_y)
            pred = F.softmax(logits, dim=1).argmax(dim=1)
            correct = torch.eq(pred, query_y).sum().item()
            accuracy = correct / (query_y.shape[0] * query_y.shape[1])

        return pred, loss, accuracy
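
A minimal driver for ProtoLearner, assuming an args namespace with the fields
referenced above and an episodic loader that yields
[support_x, support_y, query_x, query_y] batches; the loader name and the
logging cadence are illustrative, not part of the source.

learner = ProtoLearner(args, mode='train')
for i, batch in enumerate(train_loader, 1):
    # batch: [support_x, support_y, query_x, query_y]
    loss, accuracy = learner.train(batch)
    if i % 100 == 0:
        print('iter {}: loss={:.4f} acc={:.4f}'.format(
            i, loss.item(), accuracy))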