class ESNModelSNLI(torch.nn.Module):
    def __init__(self, hp, logname=''):
        super(ESNModelSNLI, self).__init__()

        input_size = 300
        output_size = 3

        self.epochs = hp['epochs']
        self.lr = hp['lr']
        self.batch_size = hp['n_batch']
        self.weight_decay = hp['weight_decay']
        self.n_layers = hp['num_layers']
        self.reservoir_size = hp['reservoir_size']
        #self.dropout = hp['dropout']
        attention_hidden_size = hp['n_attention']
        attention_heads = hp['attention_r']

        num_directions = 2

        def cell_provider(input_size_, reservoir_size_, layer, direction):
            # Forward cells (direction == 0) and backward cells use separate
            # hyperparameters.
            return ESNMultiringCell(
                input_size_,
                reservoir_size_,
                bias=True,
                contractivity_coeff=hp['scale_rec'][layer]
                if direction == 0 else hp['scale_rec_bw'][layer],
                scale_in=hp['scale_in'][layer]
                if direction == 0 else hp['scale_in_bw'][layer],
                density_in=hp['density_in'][layer]
                if direction == 0 else hp['density_in_bw'][layer],
                leaking_rate=hp['leaking_rate'][layer]
                if direction == 0 else hp['leaking_rate_bw'][layer])

        self.esn = ESNBase(cell_provider,
                           input_size,
                           self.reservoir_size,
                           num_layers=self.n_layers,
                           bidirectional=True).to(device)

        # Dimensionality reduction
        self.ff1 = torch.nn.Linear(
            self.n_layers * num_directions * self.reservoir_size,
            attention_hidden_size).to(device)

        # Pairwise attention for SNLI
        self.attn = SNLIAttention(attention_hidden_size,
                                  r=attention_heads).to(device)

        # Classifier
        self.classifier = torch.nn.Linear(attention_hidden_size,
                                          output_size).to(device)

        # A negative number of epochs requests early stopping.
        self.early_stop = self.epochs < 0
        self.epochs = abs(self.epochs)

        self.training_time = -1
        self.actual_epochs = self.epochs

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """
        input: (seq_len, batch_size, input_size)
        output: (batch_size, N_Y)
        """
        # states: (seq_len, batch, num_directions * hidden_size)
        s1, _ = self.esn.forward(x1.to(device))
        s1 = torch.tanh(self.ff1(s1))  # (seq_len, batch, n_attn)

        s2, _ = self.esn.forward(x2.to(device))
        s2 = torch.tanh(self.ff1(s2))  # (seq_len, batch, n_attn)

        # Apply pairwise attention
        embedding = self.attn.forward(s1, s2)

        return self.classifier(embedding)

    def forward_in_batches(self, dataset, batch_size):
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                pin_memory=True,
                                num_workers=6)
        _Xs = []
        for _, minibatch in enumerate(dataloader):
            _Xs += [
                self.forward(minibatch['x1'].to(device, non_blocking=True),
                             minibatch['x2'].to(device, non_blocking=True))
            ]
        return torch.cat(_Xs, dim=0)

    # FIXME Too slow
    def fit(self, train_fold, val_fold):
        """
        Fits the model, using Adam with weight decay as regularization.
        :param train_fold: training fold.
        :param val_fold: validation fold.
        :return: accuracy on the validation fold.
        """
        if self.early_stop and val_fold is None:
            raise Exception(
                "User requested early stopping but a validation set was not provided"
            )

        t_train_start = time.time()

        #weights = 1.0 / torch.Tensor([67, 945, 1004, 949, 670, 727])
        #sampler = torch.utils.data.WeightedRandomSampler(weights, len(train_fold))
        #dataloader = DataLoader(train_fold, batch_size=self.batch_size, collate_fn=collate_fn,
        #                        pin_memory=True, sampler=sampler)
        dataloader = DataLoader(train_fold,
                                shuffle=True,
                                batch_size=self.batch_size,
                                collate_fn=collate_fn,
                                pin_memory=True)

        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)
        criterion = torch.nn.CrossEntropyLoss()

        checkpoint = self.state_dict()
        best_val_accuracy = 0
        epochs_without_val_acc_improvement = 0
        patience = 10
        epoch = 0
        #for epoch in tqdm(range(1, epochs + 1), desc="epochs", dynamic_ncols=True):
        for epoch in range(1, self.epochs + 1):
            running_loss = 0.0
            num_minibatches = 0
            for i, data in enumerate(dataloader):
                # Move data to the device
                data_x1 = data['x1'].to(device, non_blocking=True)
                data_x2 = data['x2'].to(device, non_blocking=True)
                data_y = data['y'].to(device, non_blocking=True)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward + backward + optimize
                self.train()
                train_out = self.forward(data_x1, data_x2)
                loss = criterion(train_out, data_y)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_minibatches += 1

            curr_avg_loss = running_loss / num_minibatches
            if math.isnan(curr_avg_loss):
                print("Loss is NaN. Stopping.")
                break

            if val_fold is not None:
                _, val_accuracy, _ = self.performance(None, val_fold, None)
                if val_accuracy > best_val_accuracy:
                    epochs_without_val_acc_improvement = 0
                    best_val_accuracy = val_accuracy
                    checkpoint = self.state_dict()
                else:
                    epochs_without_val_acc_improvement += 1

                # Early stopping
                if self.early_stop and epochs_without_val_acc_improvement >= patience:
                    print(
                        f"Epoch {epoch}: no accuracy improvement after {patience} epochs. Early stop."
                    )
                    self.load_state_dict(checkpoint)
                    break

        # When early stopping triggered, the restored checkpoint dates back
        # `patience` epochs.
        self.actual_epochs = epoch - patience if self.early_stop else epoch

        t_train_end = time.time()
        self.training_time = t_train_end - t_train_start

        # Compute accuracy on the validation set
        _, val_accuracy, _ = self.performance(None, val_fold, None)
        return val_accuracy

    def performance(self, train_fold, val_fold, test_fold=None, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        if train_fold:
            train_accuracy, train_out, train_expected = self.performance_from_fold(
                train_fold, batch_size)
        else:
            train_accuracy, train_out, train_expected = (0, None, None)

        if val_fold:
            val_accuracy, val_out, val_expected = self.performance_from_fold(
                val_fold, batch_size)
        else:
            val_accuracy, val_out, val_expected = (0, None, None)

        if test_fold:
            test_accuracy, test_out, test_expected = self.performance_from_fold(
                test_fold, batch_size)
        else:
            test_accuracy, test_out, test_expected = (0, None, None)

        save_raw_predictions = False
        if save_raw_predictions:
            raw_preds_filename = '/home/disarli/tmp/predictions.pt'
            try:
                saved = torch.load(raw_preds_filename)
            except FileNotFoundError:
                saved = []
            saved.append({
                'train_out': train_out.cpu(),
                'train_expected': train_expected.cpu(),
                'val_out': val_out.cpu(),
                'val_expected': val_expected.cpu(),
                'test_out': test_out.cpu() if test_fold else None,
                'test_expected': test_expected.cpu() if test_fold else None,
            })
            torch.save(saved, raw_preds_filename)

        return train_accuracy, val_accuracy, test_accuracy

    def performance_from_out(self, output, expected):
        """
        Given a tensor of network outputs and a tensor of expected outputs,
        returns the performance.
        :param output: raw network outputs, one row per sample.
        :param expected: expected class labels.
        :return: accuracy.
        """
        output = output.argmax(dim=1).cpu()
        return common.accuracy(output, expected)

    def performance_from_fold(self, fold, batch_size):
        with torch.no_grad():
            self.eval()
            out = self.forward_in_batches(fold, batch_size)
            expected = torch.Tensor([d['y'] for d in fold])

            perf = self.performance_from_out(out, expected)
            return perf, out, expected
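# Usage sketch (illustrative, not part of the original module): how
# ESNModelSNLI is meant to be driven. The hyperparameter keys mirror those
# read in __init__; every value below is a hypothetical placeholder, and
# `train_fold`/`val_fold` are assumed to be datasets of {'x1', 'x2', 'y'}
# dicts compatible with collate_fn.
#
#   hp = {
#       'epochs': -50,           # negative => train with early stopping
#       'lr': 1e-3,
#       'n_batch': 128,
#       'weight_decay': 0.0,
#       'num_layers': 1,
#       'reservoir_size': 500,
#       'n_attention': 100,
#       'attention_r': 1,
#       # One entry per ESN layer; forward lists and '_bw' backward lists:
#       'scale_rec': [0.9], 'scale_rec_bw': [0.9],
#       'scale_in': [1.0], 'scale_in_bw': [1.0],
#       'density_in': [1.0], 'density_in_bw': [1.0],
#       'leaking_rate': [0.5], 'leaking_rate_bw': [0.5],
#   }
#   model = ESNModelSNLI(hp)
#   best_val_acc = model.fit(train_fold, val_fold)
#   _, val_acc, test_acc = model.performance(None, val_fold, test_fold)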
class LeakyESNAttention(torch.nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 reservoir_size: int = 100,
                 num_esn_layers: int = 1,
                 mlp_n_hidden: int = 1,
                 mlp_hidden_size: int = 100,
                 dropout: float = 0,
                 attention_type: str = 'LinSelfAttention',
                 attention_hidden_size: int = 100,
                 attention_heads: int = 1,
                 scale_rec: List[float] = (1,),
                 scale_rec_bw: List[float] = (1,),
                 scale_in: List[float] = (1,),
                 scale_in_bw: List[float] = (1,),
                 density_in: List[float] = (1,),
                 density_in_bw: List[float] = (1,),
                 leaking_rate: List[float] = (1,),
                 leaking_rate_bw: List[float] = (1,)):
        super(LeakyESNAttention, self).__init__()

        bidirectional = True

        self.input_size = input_size
        self.output_size = output_size
        self.n_layers = num_esn_layers
        self.reservoir_size = reservoir_size
        self.dropout = dropout
        self.attention_type = attention_type
        self.mlp_n_hidden = mlp_n_hidden
        self.mlp_hidden_size = mlp_hidden_size

        num_directions = 2 if bidirectional else 1

        def cell_provider(input_size_, reservoir_size_, layer, direction):
            # Forward cells (direction == 0) and backward cells use separate
            # hyperparameters.
            return ESNMultiringCell(
                input_size_,
                reservoir_size_,
                bias=True,
                contractivity_coeff=scale_rec[layer]
                if direction == 0 else scale_rec_bw[layer],
                scale_in=scale_in[layer]
                if direction == 0 else scale_in_bw[layer],
                density_in=density_in[layer]
                if direction == 0 else density_in_bw[layer],
                leaking_rate=leaking_rate[layer]
                if direction == 0 else leaking_rate_bw[layer])

        self.esn = ESNBase(cell_provider,
                           input_size,
                           reservoir_size,
                           num_layers=self.n_layers,
                           bidirectional=bidirectional)

        if self.attention_type == 'LinSelfAttention':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            self.attn = LinSelfAttention(attention_hidden_size,
                                         r=attention_heads)
            mlp_input_size = self.attn.output_features()
        elif self.attention_type == 'Attention':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            self.attn = SingleTargetAttention(attention_hidden_size)
            mlp_input_size = self.attn.output_features()
        elif self.attention_type == 'MaxPooling':
            self.ff1 = torch.nn.Linear(
                self.n_layers * num_directions * reservoir_size,
                attention_hidden_size)
            mlp_input_size = attention_hidden_size
            self.esn_bn = torch.nn.BatchNorm1d(mlp_input_size)
        elif self.attention_type == 'None':
            mlp_input_size = self.n_layers * num_directions * reservoir_size
        elif self.attention_type == 'Mean':
            mlp_input_size = self.n_layers * num_directions * reservoir_size
        else:
            raise Exception("Invalid attention type: " + self.attention_type)

        if mlp_n_hidden == 0:
            # No hidden layers: a single linear readout.
            self.mlp_hn = torch.nn.ModuleList([])
            self.mlp_out = torch.nn.Linear(mlp_input_size, output_size)
        else:
            mlp_h1 = torch.nn.Linear(mlp_input_size, mlp_hidden_size)
            self.mlp_hn = torch.nn.ModuleList([
                torch.nn.Linear(mlp_hidden_size, mlp_hidden_size)
                for _ in range(mlp_n_hidden - 1)
            ])
            self.mlp_hn.insert(0, mlp_h1)
            self.mlp_out = torch.nn.Linear(mlp_hidden_size, output_size)

        self._attnweights = None  # Attention weights for the latest minibatch.

    def readout(self, states: torch.Tensor):
        """
        :param states: (seq_len, batch_size, num_directions * hidden_size)
        :return: (..., output_size)
        """
        s = states
        for lyr in self.mlp_hn:
            s = torch.nn.functional.dropout(s, p=self.dropout)
            s = torch.relu(lyr(s))
        s = self.mlp_out(s)
        return s

    def forward(self, input: torch.Tensor, seq_lengths=None):
        """
        input: (seq_len, batch_size, input_size)
        lengths: integer list of lengths, one for each sequence in 'input'.
                 If provided, padding states are automatically ignored.
        output: (batch_size, output_size)
        """
        seq_len = input.size(0)
        batch = input.size(1)

        if self.attention_type == 'LinSelfAttention' or self.attention_type == 'Attention':
            n_attn = self.ff1.out_features

            # states: (seq_len, batch, num_directions * hidden_size)
            states, _ = self.esn.forward(input)
            states = torch.tanh(self.ff1(states))  # (seq_len, batch, n_attn)

            ## Let the recurrent network compute the states
            #states = torch.empty((seq_len, batch, n_attn), device=input.device)
            #for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
            #    # x: (num_layers * num_directions, batch, hidden_size)
            #    x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
            #    # Reduce dimensionality. x: (batch, n_attention)
            #    x = torch.tanh(self.ff1(x))
            #    states[i] = x

            # Apply attention. x: (batch, n_attention)
            x, self._attnweights = self.attn.forward(states)
        elif self.attention_type == 'MaxPooling':
            s = torch.zeros((batch, self.ff1.out_features),
                            device=input.device)
            for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
                # x: (num_layers * num_directions, batch, hidden_size)
                x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
                # x: (batch, num_layers * num_directions * hidden_size)
                x = torch.tanh(self.ff1(x))
                s = torch.max(torch.stack((s, x)), 0)[0]
            x = self.esn_bn(s)
        elif self.attention_type == 'Mean':
            n_res = self.esn.num_layers * self.esn.num_directions * self.esn.reservoir_size

            # Let the recurrent network compute the states, averaged over
            # time. The accumulator must start from zeros, not torch.empty.
            states = torch.zeros((batch, n_res), device=input.device)
            count = 0
            for i, x in enumerate(self.esn.forward_long_sequence_yld(input)):
                # x: (num_layers * num_directions, batch, hidden_size)
                x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
                states += x
                count += 1
            x = states / count
        elif self.attention_type == 'None':
            x = self.esn.forward_long_sequence(input, seq_lengths=seq_lengths)
            x = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)

        return self.readout(x)

    def loss_penalty(self):
        if self.attention_type == 'LinSelfAttention' and self.attn.r > 1:
            return LinSelfAttention.loss_penalization(self._attnweights)
        return 0
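# Usage sketch (illustrative, not part of the original module): running
# LeakyESNAttention on a random batch. Shapes follow the forward() docstring;
# all sizes below are arbitrary placeholders.
#
#   model = LeakyESNAttention(input_size=300,
#                             output_size=5,
#                             reservoir_size=200,
#                             attention_type='LinSelfAttention',
#                             attention_hidden_size=100,
#                             attention_heads=2)
#   x = torch.randn(40, 16, 300)    # (seq_len, batch_size, input_size)
#   logits = model(x)               # (batch_size, output_size)
#   penalty = model.loss_penalty()  # add to the training loss when
#                                   # attention_heads > 1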