Пример #1
0
    def _forward(self, x, weights):
        """Run the network functionally with externally supplied weights.

        The first convolution and the final linear layer take their
        parameters from ``weights`` rather than from the module itself, and
        the same mapping is handed to the ``_forward`` of every layer in
        ``layer1`` .. ``layer4``.

        Args:
            x: input tensor indexed as (channels, w, h); a leading dimension
                of size 1 is added before the first convolution.
            weights: mapping from parameter name to tensor. Must contain
                "conv1.weight", "linear.weight" and "linear.bias".

        Returns:
            Tensor reshaped to (1, 10, self.largest_w, self.largest_h).
        """
        input_w, input_h = x.shape[1], x.shape[2]
        # pad_crop(goal='pad') presumably grows x to largest_w x largest_h;
        # confirm against pad_crop's definition.
        padded = pad_crop(x,
                          self.largest_w,
                          self.largest_h,
                          input_w,
                          input_h,
                          goal='pad')
        batch = padded.unsqueeze(0)

        out = F.relu(F.conv2d(batch, weights["conv1.weight"], padding=1))

        # Each layer receives its 1-based group number and its index inside
        # the group, together with the shared weight mapping.
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for stage_no, stage in enumerate(stages, start=1):
            for idx, layer in enumerate(stage):
                out = layer._forward(stage_no, idx, weights, out)

        flat = out.view(out.size(0), -1)
        scores = F.linear(flat, weights["linear.weight"], weights["linear.bias"])

        return scores.reshape(1, 10, self.largest_w, self.largest_h)
Пример #2
0
 def forward(self, x):
     """Map one input through a single linear layer.

     The input is passed to ``pad_crop`` with ``goal='pad'``, flattened,
     fed through ``self.linear``, reshaped to
     (1, 10, largest_w, largest_h) and finally passed to ``pad_crop`` with
     ``goal='crop'`` using the original input size.

     Args:
         x: input tensor indexed as (channels, w, h).

     Returns:
         The result of the final ``pad_crop(..., goal='crop')`` call.
     """
     input_w, input_h = x.shape[1], x.shape[2]
     padded = pad_crop(x,
                       self.largest_w,
                       self.largest_h,
                       input_w,
                       input_h,
                       goal='pad')

     # Flatten everything into one row vector for the linear layer.
     flat = padded.view(-1)
     scores = self.linear(flat.unsqueeze(0))

     grid = scores.reshape(1, 10, self.largest_w, self.largest_h)
     return pad_crop(grid,
                     input_w,
                     input_h,
                     self.largest_w,
                     self.largest_h,
                     goal='crop')
Пример #3
0
 def forward(self, x):
     """Map one input through the LSTM and, optionally, attention.

     The input is passed to ``pad_crop`` with ``goal='pad'``, each channel
     is flattened into one row, and the rows are fed to ``self.lstm`` as a
     single batch. When ``self.attention_value`` is set, the LSTM output is
     reshaped and passed through ``self.attention``. The result is reshaped
     to (1, 10, largest_w, largest_h) and passed to ``pad_crop`` with
     ``goal='crop'`` using the original input size.

     Args:
         x: input tensor indexed as (channels, w, h).

     Returns:
         The result of the final ``pad_crop(..., goal='crop')`` call.
     """
     ch, input_w, input_h = x.shape[0], x.shape[1], x.shape[2]
     padded = pad_crop(x,
                       self.largest_w,
                       self.largest_h,
                       input_w,
                       input_h,
                       goal='pad')

     # One row per channel, with a leading dimension of size 1.
     sequence = padded.view(ch, -1).unsqueeze(0)
     # Only the first element returned by the LSTM is used.
     features, _state = self.lstm(sequence)

     if self.attention_value is not None:
         features = features.reshape(1, features.shape[2], 10)
         features = self.attention(features)

     grid = features.reshape(1, 10, self.largest_w, self.largest_h)
     return pad_crop(grid,
                     input_w,
                     input_h,
                     self.largest_w,
                     self.largest_h,
                     goal='crop')
Пример #4
0
 def forward(self, x):
     """Map one input through two linear layers with a leaky ReLU between.

     The input is passed to ``pad_crop`` with ``goal='pad'``, flattened,
     and run through ``self.linear``, ``F.leaky_relu`` and
     ``self.linear2``.

     Args:
         x: input tensor indexed as (channels, w, h).

     Returns:
         Tensor reshaped to (1, 10, self.out_dim[0], self.out_dim[1]).
     """
     input_w, input_h = x.shape[1], x.shape[2]
     padded = pad_crop(x,
                       self.largest_w,
                       self.largest_h,
                       input_w,
                       input_h,
                       goal='pad')

     # Flatten everything into one row vector for the linear layers.
     flat = padded.view(-1).unsqueeze(0)
     hidden = F.leaky_relu(self.linear(flat))
     scores = self.linear2(hidden)

     return scores.reshape(1, 10, self.out_dim[0], self.out_dim[1])
Пример #5
0
 def forward(self, x):
     """Map one input through two convolutions and, optionally, attention.

     A leading dimension of size 1 is added, the result is passed to
     ``pad_crop`` with ``goal='crop'`` and target size ``self.out_dim``,
     and then run through ``self.conv`` and ``self.conv2``. When
     ``self.attention_value`` is set, the feature map is reshaped so the
     size-10 dimension comes last, passed through ``self.attention``, and
     reshaped back so that dimension comes second.

     Args:
         x: input tensor indexed as (channels, w, h).

     Returns:
         The convolution output, or the reshaped attention output when
         attention is enabled.
     """
     # Sizes are read before the leading dimension is added.
     input_w, input_h = x.shape[1], x.shape[2]
     batch = x.unsqueeze(0)
     batch = pad_crop(batch,
                      self.out_dim[0],
                      self.out_dim[1],
                      input_w,
                      input_h,
                      goal='crop')

     features = self.conv2(self.conv(batch))
     if self.attention_value is None:
         return features

     dim2, dim3 = features.shape[2], features.shape[3]
     attended = self.attention(features.reshape(1, dim2, dim3, 10))
     return attended.reshape(1, 10, attended.shape[1], attended.shape[2])
Пример #6
0
    def forward(self, x):
        """Run the network using the module's own parameters.

        The input is passed to ``pad_crop`` with ``goal='pad'``, given a
        leading dimension of size 1, and run through ``conv1`` + ReLU,
        ``layer1`` .. ``layer4`` and the final linear layer.

        Args:
            x: input tensor indexed as (channels, w, h).

        Returns:
            Tensor reshaped to (1, 10, self.largest_w, self.largest_h).
        """
        input_w, input_h = x.shape[1], x.shape[2]
        padded = pad_crop(x,
                          self.largest_w,
                          self.largest_h,
                          input_w,
                          input_h,
                          goal='pad')
        batch = padded.unsqueeze(0)

        out = F.relu(self.conv1(batch))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # Flatten everything except the leading dimension.
        flat = out.view(out.size(0), -1)
        scores = self.linear(flat)

        return scores.reshape(1, 10, self.largest_w, self.largest_h)
Пример #7
0
    def train(self, tasks, model_name, criterion, n_epoch=30, lr=0.1, device = "cpu", verbose=True,  inner_lr = 0.1, inner_iter = 10, meta_size = 50):
        """Meta-train a model over several tasks with a MAML-style loop.

        For every epoch and every task, the task's training samples are
        split into a support set (the first ``meta_size`` samples) and a
        query set (the remainder). ``inner_iter`` gradient steps of size
        ``inner_lr`` are taken on the support set, producing "fast
        weights"; the mean query loss under those fast weights is the
        task's meta-loss. The mean meta-loss over all tasks is
        back-propagated through the inner steps (``create_graph=True``) and
        applied with Adam.

        After each optimiser step the un-adapted network is evaluated on
        the ``test`` and ``train`` samples of the last task in ``tasks``.
        NOTE(review): only the last task is evaluated, because the loop
        variable is reused after the task loop. Confirm whether per-task
        metrics were intended.

        args:
            tasks: re-iterable collection of tasks; each task is a mapping
                with 'train' and 'test' lists of samples, and each sample
                has an 'input' and an 'output' grid
            model_name: model class, called as
                ``model_name(device, largest_dim0, largest_dim1)``; the
                instance must provide ``_forward(x, weights)``
            criterion: loss class, instantiated without arguments
            n_epoch: number of epochs for training
            lr: learning rate of the outer (Adam) optimisation
            device: whether to use CPU or GPU
            verbose: if set to True prints the metrics after every epoch
            inner_lr: learning rate of the inner loop
            inner_iter: number of gradient steps of the inner loop
            meta_size: number of leading training samples used as the
                support set; the rest form the query set

        returns: ``(metrics, predictions, wrong_predictions)``. ``metrics``
            maps 'loss_train', 'loss_val', 'accuracy_train',
            'accuracy_train_pix', 'accuracy_val' and 'accuracy_val_pix' to
            lists holding one single-element list per epoch.
            ``predictions`` holds, per epoch, the ``(input, prediction)``
            pairs of the evaluated test samples, and ``wrong_predictions``
            the subset that did not match the labels exactly.

        raises:
            ValueError: if a task's support or query split is empty.
        """

        total_loss_train = []
        total_loss_val = []
        total_accuracy_val = []
        total_accuracy_train = []
        total_accuracy_val_pix = []
        total_accuracy_train_pix = []
        total_predictions = []
        total_wrong_predictions = []

        loss_fn = criterion()

        # Largest grid extent over every input of every task; used both to
        # size the model and as the padding target.
        sh1_big = 0
        sh2_big = 0
        for task in tasks:
            for split in ('train', 'test'):
                for sample in task[split]:
                    shape = sample['input'].shape
                    sh1_big = max(sh1_big, shape[0])
                    sh2_big = max(sh2_big, shape[1])

        net = model_name(device, sh1_big, sh2_big).to(device)
        optimizer = Adam(net.parameters(), lr=lr)

        def _mean_loss(weights, xs, ys):
            # Mean criterion value of the functional forward pass over (xs, ys).
            total = 0
            for x, y in zip(xs, ys):
                total = total + loss_fn(net._forward(x, weights), y)
            return total / len(xs)

        def _evaluate(samples):
            # Score the current network on `samples` without tracking gradients.
            stats = {'loss': 0, 'correct': 0, 'correct_pix': 0, 'total': 0,
                     'pixels': 0, 'predictions': [], 'wrong': []}
            with torch.no_grad():
                for sample in samples:
                    img = FloatTensor(expand(sample['input'])).to(device)
                    y = LongTensor(sample['output'])
                    labels = pad_crop(y, sh1_big, sh2_big, y.shape[0], y.shape[1], "pad").unsqueeze(0).to(device)
                    scores = net(img)
                    _, pred = torch.max(scores, 1)
                    hits = torch.eq(pred, labels)
                    # 1 only when every position of the grid is right.
                    exact = int(torch.all(hits).item())

                    stats['predictions'].append((img, pred))
                    if exact == 0:
                        stats['wrong'].append((img, pred))
                    stats['pixels'] += pred.shape[1] * pred.shape[2]
                    stats['total'] += labels.size(0)
                    stats['correct'] += exact
                    stats['correct_pix'] += hits.sum().item()
                    stats['loss'] += loss_fn(scores, labels).item()
            return stats

        for epoch in tqdm(range(n_epoch)):

            # Switch to training mode *before* the forward passes: the
            # evaluation at the end of the previous epoch left the network
            # in eval mode.
            net.train()

            losses = []
            for task in tasks:

                inputs = []
                targets = []
                for sample in task["train"]:
                    x = Tensor(sample['input'])
                    x = pad_crop(x, sh1_big, sh2_big, x.shape[0], x.shape[1], goal = "pad")
                    inputs.append(FloatTensor(expand(x).float()).to(device))
                    y = Tensor(sample['output'])
                    y = pad_crop(y, sh1_big, sh2_big, y.shape[0], y.shape[1], goal = "pad")
                    targets.append(LongTensor(y.long()).unsqueeze(0).to(device))

                # Support set drives the inner loop, query set the meta-loss.
                inputs_train, inputs_val = inputs[:meta_size], inputs[meta_size:]
                targets_train, targets_val = targets[:meta_size], targets[meta_size:]
                if not inputs_train or not inputs_val:
                    raise ValueError(
                        "task has {} training samples, which leaves an empty "
                        "support or query split for meta_size={}".format(
                            len(inputs), meta_size))

                fast_weights = OrderedDict(net.named_parameters())

                for _ in range(inner_iter):
                    support_loss = _mean_loss(fast_weights, inputs_train, targets_train)
                    # create_graph=True keeps the update differentiable so the
                    # meta-gradient can flow through the inner steps.
                    gradients = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)
                    fast_weights = OrderedDict(
                        (name, param - inner_lr * grad)
                        for (name, param), grad in zip(fast_weights.items(), gradients))

                query_loss = _mean_loss(fast_weights, inputs_val, targets_val)
                losses.append(query_loss.float())

            # A single backward pass over the mean of the task losses; a
            # per-task backward here would be discarded by zero_grad().
            optimizer.zero_grad()
            meta_loss = torch.stack(losses).mean()
            meta_loss.backward()
            optimizer.step()

            # Evaluation of the un-adapted network on the last task only
            # (see NOTE in the docstring).
            net.eval()
            last_task = task
            val = _evaluate(last_task['test'])
            trn = _evaluate(last_task['train'])

            loss_train = [trn['loss'] / len(last_task['train'])]
            loss_val = [val['loss'] / len(last_task['test'])]
            accuracy_val = [100 * val['correct'] / val['total']]
            accuracy_val_pix = [100 * val['correct_pix'] / val['pixels']]
            accuracy_train = [100 * trn['correct'] / trn['total']]
            accuracy_train_pix = [100 * trn['correct_pix'] / trn['pixels']]

            if verbose:
                print('\nEpoch: ['+str(epoch+1)+'/'+str(n_epoch)+']')
                print('Train loss is: {}'.format(loss_train[-1]))
                print('Validation loss is: {}'.format(loss_val[-1]))
                print('Train accuracy is: {} %'.format(accuracy_train[-1]))
                print('Train accuracy for pixels is: {} %'.format(accuracy_train_pix[-1]))
                print('Validation accuracy is: {} %'.format(accuracy_val[-1]))
                print('Validation accuracy for pixels is: {} %'.format(accuracy_val_pix[-1]))

            total_loss_train.append(loss_train)
            total_loss_val.append(loss_val)
            total_accuracy_train.append(accuracy_train)
            total_accuracy_train_pix.append(accuracy_train_pix)
            total_accuracy_val.append(accuracy_val)
            total_accuracy_val_pix.append(accuracy_val_pix)
            # Store this epoch's predictions (the list used to be appended
            # to itself, which made it self-referential and useless).
            total_predictions.append(val['predictions'])
            total_wrong_predictions.append(val['wrong'])

        metrics = {'loss_train': total_loss_train, 'loss_val': total_loss_val, 'accuracy_train':total_accuracy_train,
                   'accuracy_train_pix': total_accuracy_train_pix, 'accuracy_val':total_accuracy_val,
                   'accuracy_val_pix': total_accuracy_val_pix}
        final_pred = total_predictions

        return metrics, final_pred, total_wrong_predictions