Code Example #1
import copy
import math
import time
from typing import List

import torch
from torch.utils.data import DataLoader

# Assumed to be provided by the surrounding project code: ESNBase,
# ESNMultiringCell, SNLIAttention, LinSelfAttention, SingleTargetAttention,
# collate_fn, the `common` accuracy helper, and the global `device`.

class ESNModelSNLI(torch.nn.Module):
    def __init__(self, hp, logname=''):
        super(ESNModelSNLI, self).__init__()

        input_size = 300
        output_size = 3
        self.epochs = hp['epochs']
        self.lr = hp['lr']
        self.batch_size = hp['n_batch']
        self.weight_decay = hp['weight_decay']
        self.n_layers = hp['num_layers']
        self.reservoir_size = hp['reservoir_size']
        #self.dropout = hp['dropout']
        attention_hidden_size = hp['n_attention']
        attention_heads = hp['attention_r']

        num_directions = 2

        def cell_provider(input_size_, reservoir_size_, layer, direction):
            return ESNMultiringCell(
                input_size_,
                reservoir_size_,
                bias=True,
                contractivity_coeff=hp['scale_rec'][layer]
                if direction == 0 else hp['scale_rec_bw'][layer],
                scale_in=hp['scale_in'][layer]
                if direction == 0 else hp['scale_in_bw'][layer],
                density_in=hp['density_in'][layer]
                if direction == 0 else hp['density_in_bw'][layer],
                leaking_rate=hp['leaking_rate'][layer]
                if direction == 0 else hp['leaking_rate_bw'][layer])

        self.esn = ESNBase(cell_provider,
                           input_size,
                           self.reservoir_size,
                           num_layers=self.n_layers,
                           bidirectional=True).to(device)

        # Dimensionality reduction
        self.ff1 = torch.nn.Linear(
            self.n_layers * num_directions * self.reservoir_size,
            attention_hidden_size).to(device)

        # Pairwise attention for SNLI
        self.attn = SNLIAttention(attention_hidden_size,
                                  r=attention_heads).to(device)

        # Classifier
        self.classifier = torch.nn.Linear(attention_hidden_size,
                                          output_size).to(device)

        # Convention: a negative 'epochs' hyperparameter enables early
        # stopping; training runs for at most abs(epochs) epochs.
        self.early_stop = self.epochs < 0
        self.epochs = abs(self.epochs)

        self.training_time = -1
        self.actual_epochs = self.epochs

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """
        x1, x2: (seq_len, batch_size, input_size)
        output: (batch_size, output_size)
        """

        s1, _ = self.esn.forward(x1.to(
            device))  # states: (seq_len, batch, num_directions * hidden_size)
        s1 = torch.tanh(self.ff1(s1))  # states: (seq_len, batch, n_attn)

        s2, _ = self.esn.forward(x2.to(
            device))  # states: (seq_len, batch, num_directions * hidden_size)
        s2 = torch.tanh(self.ff1(s2))  # states: (seq_len, batch, n_attn)

        # Apply Attention
        embedding = self.attn.forward(s1, s2)

        return self.classifier(embedding)

    def forward_in_batches(self, dataset, batch_size):
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                pin_memory=True,
                                num_workers=6)

        _Xs = []
        for _, minibatch in enumerate(dataloader):
            _Xs += [
                self.forward(minibatch['x1'].to(device, non_blocking=True),
                             minibatch['x2'].to(device, non_blocking=True))
            ]

        return torch.cat(_Xs, dim=0)  # FIXME Too slow

    def fit(self, train_fold, val_fold):
        """
        Fits the model using Adam, with self.weight_decay as L2 regularization
        and optional early stopping on validation accuracy.
        :param train_fold: training fold.
        :param val_fold: validation fold; required when early stopping is enabled.
        :return: accuracy on the validation fold after training.
        """

        if self.early_stop and val_fold is None:
            raise Exception(
                "User requested early stopping but a validation set was not provided"
            )

        t_train_start = time.time()

        #weights = 1.0 / torch.Tensor([67, 945, 1004, 949, 670, 727])
        #sampler = torch.utils.data.WeightedRandomSampler(weights, len(train_fold))
        #dataloader = DataLoader(train_fold, batch_size=self.batch_size, collate_fn=collate_fn,
        #                        pin_memory=True, sampler=sampler)

        dataloader = DataLoader(train_fold,
                                shuffle=True,
                                batch_size=self.batch_size,
                                collate_fn=collate_fn,
                                pin_memory=True)

        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)

        criterion = torch.nn.CrossEntropyLoss()

        # Deep-copy the state dict: state_dict() returns references to the live
        # parameters, so a plain copy would silently track further training.
        checkpoint = copy.deepcopy(self.state_dict())
        best_val_accuracy = 0
        epochs_without_val_acc_improvement = 0
        patience = 10
        epoch = 0
        #for epoch in tqdm(range(1, epochs + 1), desc="epochs", dynamic_ncols=True):
        for epoch in range(1, self.epochs + 1):
            running_loss = 0.0
            num_minibatches = 0
            for i, data in enumerate(dataloader):
                # Move data to devices
                data_x1 = data['x1'].to(device, non_blocking=True)
                data_x2 = data['x2'].to(device, non_blocking=True)
                data_y = data['y'].to(device, non_blocking=True)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                self.train()

                train_out = self.forward(data_x1, data_x2)

                loss = criterion(train_out, data_y)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_minibatches += 1

            curr_avg_loss = running_loss / num_minibatches
            if math.isnan(curr_avg_loss):
                print("Loss is NaN. Stopping.")
                break

            if val_fold is not None:
                _, val_accuracy, _ = self.performance(None, val_fold, None)

                if val_accuracy > best_val_accuracy:
                    epochs_without_val_acc_improvement = 0
                    best_val_accuracy = val_accuracy
                    checkpoint = copy.deepcopy(self.state_dict())
                else:
                    epochs_without_val_acc_improvement += 1

                # Early stopping
                if self.early_stop and epochs_without_val_acc_improvement >= patience:
                    print(
                        f"Epoch {epoch}: no accuracy improvement after {patience} epochs. Early stop."
                    )
                    self.load_state_dict(checkpoint)
                    break

        self.actual_epochs = epoch - patience if self.early_stop else epoch

        t_train_end = time.time()
        self.training_time = t_train_end - t_train_start

        # Compute accuracy on validation set
        _, val_accuracy, _ = self.performance(None, val_fold, None)
        return val_accuracy

    def performance(self,
                    train_fold,
                    val_fold,
                    test_fold=None,
                    batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        if train_fold:
            train_accuracy, train_out, train_expected = self.performance_from_fold(
                train_fold, batch_size)
        else:
            train_accuracy, train_out, train_expected = (0, None, None)

        if val_fold:
            val_accuracy, val_out, val_expected = self.performance_from_fold(
                val_fold, batch_size)
        else:
            val_accuracy, val_out, val_expected = (0, None, None)

        if test_fold:
            test_accuracy, test_out, test_expected = self.performance_from_fold(
                test_fold, batch_size)
        else:
            test_accuracy, test_out, test_expected = (0, None, None)

        save_raw_predictions = False
        if save_raw_predictions:
            raw_preds_filename = '/home/disarli/tmp/predictions.pt'
            try:
                saved = torch.load(raw_preds_filename)
            except FileNotFoundError:
                saved = []
            saved.append({
                'train_out': train_out.cpu(),
                'train_expected': train_expected.cpu(),
                'val_out': val_out.cpu(),
                'val_expected': val_expected.cpu(),
                'test_out': test_out.cpu() if test_fold else None,
                'test_expected': test_expected.cpu() if test_fold else None,
            })
            torch.save(saved, raw_preds_filename)

        return train_accuracy, val_accuracy, test_accuracy

    def performance_from_out(self, output, expected):
        """
        Given a tensor of network outputs and a tensor of expected outputs, returns the accuracy.
        :param output: raw network outputs of shape (batch, n_classes).
        :param expected: expected class labels of shape (batch,).
        :return: classification accuracy.
        """
        output = output.argmax(dim=1).cpu()

        return common.accuracy(output, expected)

    def performance_from_fold(self, fold, batch_size):
        with torch.no_grad():
            self.eval()

            out = self.forward_in_batches(fold, batch_size)
            expected = torch.Tensor([d['y'] for d in fold])

            perf = self.performance_from_out(out, expected)
            return perf, out, expected
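
A minimal usage sketch for ESNModelSNLI follows. The hyperparameter values, the
dataset objects (train_fold, val_fold, test_fold) and the per-layer lists are
illustrative placeholders, not values from the original experiments; the keys
mirror the hp dictionary read in __init__, and each per-layer list needs one
entry per ESN layer (the backward direction uses the separate *_bw keys).

# Illustrative configuration only; all values below are assumptions.
hp = {
    'epochs': -50,            # negative value enables early stopping (cap: 50 epochs)
    'lr': 1e-3,
    'n_batch': 128,
    'weight_decay': 1e-5,
    'num_layers': 1,
    'reservoir_size': 500,
    'n_attention': 200,
    'attention_r': 1,
    'scale_rec': [0.9], 'scale_rec_bw': [0.9],
    'scale_in': [1.0], 'scale_in_bw': [1.0],
    'density_in': [1.0], 'density_in_bw': [1.0],
    'leaking_rate': [0.1], 'leaking_rate_bw': [0.1],
}

model = ESNModelSNLI(hp)
# train_fold / val_fold / test_fold are assumed to be SNLI dataset objects
# compatible with collate_fn (each item a dict with 'x1', 'x2', 'y').
val_accuracy = model.fit(train_fold, val_fold)
train_acc, val_acc, test_acc = model.performance(train_fold, val_fold, test_fold)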


class LeakyESNAttention(torch.nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 reservoir_size: int = 100,
                 num_esn_layers: int = 1,
                 mlp_n_hidden: int = 1,
                 mlp_hidden_size: int = 100,
                 dropout: float = 0,
                 attention_type: str = 'LinSelfAttention',
                 attention_hidden_size: int = 100,
                 attention_heads: int = 1,
                 scale_rec: List[float] = (1, ),
                 scale_rec_bw: List[float] = (1, ),
                 scale_in: List[float] = (1, ),
                 scale_in_bw: List[float] = (1, ),
                 density_in: List[float] = (1, ),
                 density_in_bw: List[float] = (1, ),
                 leaking_rate: List[float] = (1, ),
                 leaking_rate_bw: List[float] = (1, )):
        super(LeakyESNAttention, self).__init__()

        bidirectional = True

        self.input_size = input_size
        self.output_size = output_size
        self.n_layers = num_esn_layers
        self.reservoir_size = reservoir_size
        self.dropout = dropout
        self.attention_type = attention_type
        self.mlp_n_hidden = mlp_n_hidden
        self.mlp_hidden_size = mlp_hidden_size

        num_directions = 2 if bidirectional else 1

        def cell_provider(input_size_, reservoir_size_, layer, direction):
            return ESNMultiringCell(input_size_,
                                    reservoir_size_,
                                    bias=True,
                                    contractivity_coeff=scale_rec[layer]
                                    if direction == 0 else scale_rec_bw[layer],
                                    scale_in=scale_in[layer]
                                    if direction == 0 else scale_in_bw[layer],
                                    density_in=density_in[layer]
                                    if direction == 0 else density_in_bw[layer],
                                    leaking_rate=leaking_rate[layer]
                                    if direction == 0 else leaking_rate_bw[layer])

        self.esn = ESNBase(cell_provider,
                           input_size,
                           reservoir_size,
                           num_layers=self.n_layers,
                           bidirectional=bidirectional)

        if self.attention_type == 'LinSelfAttention':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            self.attn = LinSelfAttention(attention_hidden_size,
                                         r=attention_heads)
            mlp_input_size = self.attn.output_features()
        elif self.attention_type == 'Attention':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            self.attn = SingleTargetAttention(attention_hidden_size)
            mlp_input_size = self.attn.output_features()
        elif self.attention_type == 'MaxPooling':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            mlp_input_size = attention_hidden_size
            self.esn_bn = torch.nn.BatchNorm1d(mlp_input_size)
        elif self.attention_type == 'None':
            mlp_input_size = self.n_layers * num_directions * reservoir_size
        elif self.attention_type == 'Mean':
            mlp_input_size = self.n_layers * num_directions * reservoir_size
        else:
            raise Exception("Invalid attention type: " + self.attention_type)

        if mlp_n_hidden == 0:
            self.mlp_hn = torch.nn.ModuleList([])
            self.mlp_out = torch.nn.Linear(mlp_input_size, output_size)
        else:
            mlp_h1 = torch.nn.Linear(mlp_input_size, mlp_hidden_size)
            self.mlp_hn = torch.nn.ModuleList([
                torch.nn.Linear(mlp_hidden_size, mlp_hidden_size)
                for _ in range(mlp_n_hidden - 1)
            ])
            self.mlp_hn.insert(0, mlp_h1)
            self.mlp_out = torch.nn.Linear(mlp_hidden_size, output_size)

        self._attnweights = None  # Attention weights for the latest minibatch.

    def readout(self, states: torch.Tensor):
        """
        :param states: (seq_len, batch_size, num_directions * hidden_size)
        :return:
        """
        s = states
        for lyr in self.mlp_hn:
            s = torch.nn.functional.dropout(s, p=self.dropout)
            s = torch.relu(lyr(s))

        s = self.mlp_out(s)
        return s

    def forward(self, input: torch.Tensor, seq_lengths=None):
        """
        input: (seq_len, batch_size, input_size)
        seq_lengths: integer list of lengths, one for each sequence in 'input'.
                     If provided, padding states are automatically ignored.
        output: (batch_size, output_size)
        """
        seq_len = input.size(0)
        batch = input.size(1)

        if self.attention_type == 'LinSelfAttention' or self.attention_type == 'Attention':
            n_attn = self.ff1.out_features

            states, _ = self.esn.forward(
                input
            )  # states: (seq_len, batch, num_directions * hidden_size)
            states = torch.tanh(
                self.ff1(states))  # states: (seq_len, batch, n_attn)

            ## Let the recurrent network compute the states
            #states = torch.empty((seq_len, batch, n_attn), device=input.device)
            #for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
            #    # x: (num_layers * num_directions, batch, hidden_size)
            #    x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
            #    # Reduce dimensionality. x: (batch, n_attention)
            #    x = torch.tanh(self.ff1(x))
            #    states[i] = x

            # Apply Attention
            x, self._attnweights = self.attn.forward(
                states)  # x: (batch, n_attention)

        elif self.attention_type == 'MaxPooling':
            s = torch.zeros((batch, self.ff1.out_features),
                            device=input.device)
            for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
                # x: (num_layers * num_directions, batch, hidden_size)
                x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
                # x: (batch, num_layers * num_directions * hidden_size)
                x = torch.tanh(self.ff1(x))
                # Element-wise running maximum over time steps (temporal max pooling)
                s = torch.max(torch.stack((s, x)), 0)[0]
            x = self.esn_bn(s)

        elif self.attention_type == 'Mean':
            n_res = self.esn.num_layers * self.esn.num_directions * self.esn.reservoir_size

            # Let the recurrent network compute the states
            # Accumulator must start from zeros (torch.empty is uninitialized).
            states = torch.zeros((batch, n_res), device=input.device)
            count = 0
            for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
                # x: (num_layers * num_directions, batch, hidden_size)
                x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
                states += x
                count += 1
            x = states / count

        elif self.attention_type == 'None':
            x = self.esn.forward_long_sequence(input, seq_lengths=seq_lengths)
            x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)

        return self.readout(x)

    def loss_penalty(self):
        if self.attention_type == 'LinSelfAttention' and self.attn.r > 1:
            return LinSelfAttention.loss_penalization(self._attnweights)
        return 0
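
Below is a minimal, illustrative forward pass through LeakyESNAttention on random
data; the sizes are placeholders and the tensor shapes follow the
(seq_len, batch_size, input_size) convention documented in forward().

# Hypothetical sizes; any mutually consistent values work.
model = LeakyESNAttention(input_size=300,
                          output_size=3,
                          reservoir_size=200,
                          num_esn_layers=1,
                          mlp_n_hidden=1,
                          mlp_hidden_size=100,
                          attention_type='LinSelfAttention',
                          attention_hidden_size=100,
                          attention_heads=2)

# Dummy batch: 25 time steps, 16 sequences, 300-dimensional embeddings.
dummy_input = torch.randn(25, 16, 300)
dummy_labels = torch.randint(0, 3, (16,))

logits = model(dummy_input)                      # (16, 3)
loss = torch.nn.functional.cross_entropy(logits, dummy_labels)
loss = loss + model.loss_penalty()               # attention penalty when r > 1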