def _forward(self, x, weights):
    """
    functional forward pass: runs the network using the tensors in `weights`
    instead of the module's own conv1/linear parameters

    args:
        x: input tensor; dims 1 and 2 are read as its width and height
        weights: mapping from parameter name to tensor. "conv1.weight",
            "linear.weight" and "linear.bias" are read here; the mapping is
            also handed to every sub-layer's `_forward`
    returns:
        the linear output reshaped to (1, 10, self.largest_w, self.largest_h)
    """
    width, height = x.shape[1], x.shape[2]
    padded = pad_crop(x, self.largest_w, self.largest_h, width, height, goal='pad')
    batch = padded.unsqueeze(0)  # add a leading batch dimension

    out = F.relu(F.conv2d(batch, weights["conv1.weight"], padding=1))

    # each sub-layer receives its 1-based stage number and its index inside the
    # stage, plus the full weight mapping
    stages = (self.layer1, self.layer2, self.layer3, self.layer4)
    for stage_no, stage in enumerate(stages, start=1):
        for block_no, block in enumerate(stage):
            out = block._forward(stage_no, block_no, weights, out)

    flat = out.view(out.size(0), -1)
    scores = F.linear(flat, weights["linear.weight"], weights["linear.bias"])
    return torch.reshape(scores, (1, 10, self.largest_w, self.largest_h))
def forward(self, x):
    """
    pads the input to the largest grid, applies the linear layer on the
    flattened tensor and passes the result through pad_crop(goal='crop')
    with the original input size

    args:
        x: input tensor; dims 1 and 2 are read as its width and height
    returns:
        the output of the final pad_crop call, whose input is shaped
        (1, 10, self.largest_w, self.largest_h)
    """
    input_w, input_h = x.shape[1], x.shape[2]
    big_w, big_h = self.largest_w, self.largest_h

    padded = pad_crop(x, big_w, big_h, input_w, input_h, goal='pad')
    flat = padded.view(-1).unsqueeze(0)  # one row holding every element
    scores = self.linear(flat)
    grid = torch.reshape(scores, (1, 10, big_w, big_h))
    return pad_crop(grid, input_w, input_h, big_w, big_h, goal='crop')
def forward(self, x):
    """
    pads the input to the largest grid, runs it through the LSTM (and the
    attention module when one is configured) and passes the result through
    pad_crop(goal='crop') with the original input size

    args:
        x: input tensor; dim 0 is read as the channel count, dims 1 and 2
            as width and height
    returns:
        the output of the final pad_crop call, whose input is shaped
        (1, 10, self.largest_w, self.largest_h)
    """
    ch, input_w, input_h = x.shape[0], x.shape[1], x.shape[2]
    big_w, big_h = self.largest_w, self.largest_h

    padded = pad_crop(x, big_w, big_h, input_w, input_h, goal='pad')
    # one row per channel, then a leading dimension of size 1
    sequence = padded.view(ch, -1).unsqueeze(0)
    hidden, _state = self.lstm(sequence)

    # NOTE(review): only the regrouping + attention call are treated as
    # conditional; the final reshape and crop run on both paths
    if self.attention_value is not None:
        regrouped = torch.reshape(hidden, (1, hidden.shape[2], 10))
        hidden = self.attention(regrouped)

    grid = torch.reshape(hidden, (1, 10, big_w, big_h))
    return pad_crop(grid, input_w, input_h, big_w, big_h, goal='crop')
def forward(self, x):
    """
    pads the input to the largest grid and maps the flattened tensor through
    linear -> leaky ReLU -> linear2

    args:
        x: input tensor; dims 1 and 2 are read as its width and height
    returns:
        tensor reshaped to (1, 10, self.out_dim[0], self.out_dim[1])
    """
    input_w, input_h = x.shape[1], x.shape[2]
    padded = pad_crop(x, self.largest_w, self.largest_h, input_w, input_h, goal='pad')

    flat = padded.view(-1).unsqueeze(0)  # one row holding every element
    hidden = F.leaky_relu(self.linear(flat))
    scores = self.linear2(hidden)

    out_w, out_h = self.out_dim[0], self.out_dim[1]
    return torch.reshape(scores, (1, 10, out_w, out_h))
def forward(self, x):
    """
    adds a leading dimension, passes the input through pad_crop(goal='crop')
    with self.out_dim as target size, then applies conv and conv2 (and the
    attention module when one is configured)

    args:
        x: input tensor; dims 1 and 2 are read as its width and height
    returns:
        the conv2 output when self.attention_value is None, otherwise the
        attention output reshaped so that dim 1 has size 10
    """
    input_w, input_h = x.shape[1], x.shape[2]
    batch = x.unsqueeze(0)
    cropped = pad_crop(batch, self.out_dim[0], self.out_dim[1], input_w, input_h, goal='crop')
    feats = self.conv2(self.conv(cropped))

    if self.attention_value is None:
        return feats

    # NOTE(review): both reshapes are treated as part of the attention branch,
    # since the second one reads the dims produced by the first
    regrouped = torch.reshape(feats, (1, feats.shape[2], feats.shape[3], 10))
    attended = self.attention(regrouped)
    return torch.reshape(attended, (1, 10, attended.shape[1], attended.shape[2]))
def forward(self, x):
    """
    pads the input to the largest grid and runs conv1 -> ReLU -> layer1..layer4
    -> linear using the module's own parameters

    args:
        x: input tensor; dims 1 and 2 are read as its width and height
    returns:
        the linear output reshaped to (1, 10, self.largest_w, self.largest_h)
    """
    input_w, input_h = x.shape[1], x.shape[2]
    padded = pad_crop(x, self.largest_w, self.largest_h, input_w, input_h, goal='pad')

    out = F.relu(self.conv1(padded.unsqueeze(0)))
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        out = stage(out)

    flat = out.view(out.size(0), -1)
    scores = self.linear(flat)
    return torch.reshape(scores, (1, 10, self.largest_w, self.largest_h))
def train(self, tasks, model_name, criterion, n_epoch=30, lr=0.1, device="cpu", verbose=True, inner_lr=0.1, inner_iter=10, meta_size=50):
    """
    trains the given model with a meta-learning loop: for every task the
    weights are adapted for `inner_iter` gradient steps on the first
    `meta_size` train samples, the adapted weights are scored on the remaining
    train samples, and Adam steps on the mean of those scores once per epoch

    NOTE(review): the nesting of this function was reconstructed from the
    statement order (meta-step and evaluation once per epoch, evaluation on
    the last task of `tasks`) - confirm against the intended behaviour.

    args:
        tasks: re-iterable collection of tasks; each task is indexed with
            'train' and 'test', giving sequences of samples that hold an
            'input' and an 'output' grid exposing `.shape`
        criterion: loss class/factory, instantiated here with no arguments
        model_name: callable building the model, invoked as
            model_name(device, largest_dim0, largest_dim1); the model must
            provide `_forward(x, weights)`
        n_epoch: number of epochs for training
        lr: learning rate of the outer Adam optimizer
        device: device the model and the tensors are moved to
        verbose: if set to True prints the metrics of every epoch
        inner_lr: learning rate of the inner loop
        inner_iter: number of gradient steps of the inner loop
        meta_size: number of train samples used for the inner loop; the
            remaining train samples score the adapted weights. Both parts
            must be non-empty, otherwise averaging the loss divides by zero
    returns:
        the obtained metrics (dict with one single-element list per epoch
        under each key), the predictions on the test samples per epoch and
        the wrong ones per epoch
    """
    total_loss_train = []
    total_loss_val = []
    total_accuracy_val = []
    total_accuracy_train = []
    total_accuracy_val_pix = []
    total_accuracy_train_pix = []
    total_predictions = []
    total_wrong_predictions = []

    criterion = criterion()

    # largest grid over every train/test input: all samples get padded to it
    grids = [
        sample['input']
        for task in tasks
        for split in ('train', 'test')
        for sample in task[split]
    ]
    sh1_big = max((grid.shape[0] for grid in grids), default=0)
    sh2_big = max((grid.shape[1] for grid in grids), default=0)

    net = model_name(device, sh1_big, sh2_big).to(device)
    optimizer = Adam(net.parameters(), lr=lr)

    def mean_loss(xs, ys, weights):
        # average criterion of the functional forward pass over the samples
        total = 0
        for x, y in zip(xs, ys):
            total += criterion(net._forward(x, weights), y)
        return total / len(xs)

    def evaluate(samples, collect):
        # returns (mean loss, % of fully correct grids, % of correct pixels,
        # predictions, wrong predictions); the two lists are only filled
        # when `collect` is True
        loss_sum = 0
        exact_hits = 0
        pixel_hits = 0
        n_grids = 0
        n_pixels = 0
        kept, wrong = [], []
        for sample in samples:
            img = FloatTensor(expand(sample['input'])).to(device)
            y = LongTensor(sample['output'])
            labels = pad_crop(y, sh1_big, sh2_big, y.shape[0], y.shape[1], "pad").unsqueeze(0).to(device)
            scores = net(img)
            _, pred = torch.max(scores.data, 1)
            same = torch.eq(pred, labels)
            exact = torch.all(same).sum().item()  # 1 if every pixel matches
            n_pixels += pred.shape[1] * pred.shape[2]
            n_grids += labels.size(0)
            exact_hits += exact
            pixel_hits += same.sum().item()
            loss_sum += criterion(scores, labels).item()
            if collect:
                kept.append((img, pred))
                if exact == 0:
                    wrong.append((img, pred))
        return (
            loss_sum / len(samples),
            100 * exact_hits / n_grids,
            100 * pixel_hits / n_pixels,
            kept,
            wrong,
        )

    for epoch in tqdm(range(n_epoch)):
        losses = []
        for task in tasks:
            inputs = []
            outputs = []
            for sample in task["train"]:
                x = Tensor(sample['input'])
                x = pad_crop(x, sh1_big, sh2_big, x.shape[0], x.shape[1], goal="pad")
                inputs.append(FloatTensor(expand(x).float()).to(device))
                y = Tensor(sample['output'])
                y = pad_crop(y, sh1_big, sh2_big, y.shape[0], y.shape[1], goal="pad")
                outputs.append(LongTensor(y.long()).unsqueeze(0).to(device))

            inputs_train, inputs_val = inputs[:meta_size], inputs[meta_size:]
            outputs_train, outputs_val = outputs[:meta_size], outputs[meta_size:]

            # inner loop: differentiable gradient steps on a copy of the weights
            fast_weights = OrderedDict(net.named_parameters())
            for _ in range(inner_iter):
                loss = mean_loss(inputs_train, outputs_train, fast_weights)
                gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for (name, param), grad in zip(fast_weights.items(), gradients)
                )

            # score the adapted weights on the held-out part of the train set
            loss = mean_loss(inputs_val, outputs_val, fast_weights)
            # NOTE(review): the gradients accumulated by this call are cleared
            # by optimizer.zero_grad() below before the meta step; kept as in
            # the original, retain_graph keeps the graph for meta_loss.backward()
            loss.backward(retain_graph=True)
            losses.append(loss.float())

        # outer (meta) step on the mean post-adaptation loss of all tasks
        net.train()
        optimizer.zero_grad()
        meta_loss = torch.stack(losses).mean()
        meta_loss.backward()
        optimizer.step()

        # evaluation uses `task` as left by the loop above, i.e. the last task
        net.eval()
        with torch.no_grad():
            val_loss, val_accuracy, val_accuracy_pix, predictions, wrong_pred = evaluate(task['test'], True)
            train_loss, train_accuracy, train_accuracy_pix, _, _ = evaluate(task['train'], False)

        if verbose:
            print('\nEpoch: ['+str(epoch+1)+'/'+str(n_epoch)+']')
            print('Train loss is: {}'.format(train_loss))
            print('Validation loss is: {}'.format(val_loss))
            print('Train accuracy is: {} %'.format(train_accuracy))
            print('Train accuracy for pixels is: {} %'.format(train_accuracy_pix))
            print('Validation accuracy is: {} %'.format(val_accuracy))
            print('Validation accuracy for pixels is: {} %'.format(val_accuracy_pix))

        # one single-element list per epoch, matching the original layout
        total_loss_train.append([train_loss])
        total_loss_val.append([val_loss])
        total_accuracy_train.append([train_accuracy])
        total_accuracy_train_pix.append([train_accuracy_pix])
        total_accuracy_val.append([val_accuracy])
        total_accuracy_val_pix.append([val_accuracy_pix])
        # fix: the original appended `total_predictions` to itself here
        total_predictions.append(predictions)
        total_wrong_predictions.append(wrong_pred)

    metrics = {
        'loss_train': total_loss_train,
        'loss_val': total_loss_val,
        'accuracy_train': total_accuracy_train,
        'accuracy_train_pix': total_accuracy_train_pix,
        'accuracy_val': total_accuracy_val,
        'accuracy_val_pix': total_accuracy_val_pix,
    }
    final_pred = total_predictions
    return metrics, final_pred, total_wrong_predictions