Ejemplo n.º 1
0
    def forward(self, input_dict):
        """Run the shared encoder, then the mask decoder and the species classifier.

        Args:
            input_dict: mapping with an ``"img"`` entry; it is passed to
                ``self.encode`` and, together with the latent code, to
                ``self.classify``.

        Returns:
            DeviceDict with ``"mask"`` (output of ``self.decode`` on the latent
            code) and ``"species"`` (output of ``self.classify``).
        """
        image = input_dict["img"]
        latent = self.encode(image)
        # Both heads consume the same latent code; the classifier additionally
        # receives the raw image. Dict values are evaluated in order, so decode
        # still runs before classify.
        return DeviceDict({
            "mask": self.decode(latent),
            "species": self.classify(image, latent),
        })
Ejemplo n.º 2
0
 def forward(self, input_dict):
     """Predict species scores from the image with a two-stage conv net and MLP head.

     Args:
         input_dict: mapping with an ``"img"`` entry fed to ``self.conv1``.
             Assumes a batched (N, C, H, W) tensor, as produced by the
             DataLoader — TODO confirm no caller passes unbatched images.

     Returns:
         DeviceDict with a single ``"species"`` entry: the output of
         ``self.fc2`` passed through leaky ReLU.
     """
     x = input_dict["img"]
     x = self.pool(F.leaky_relu(self.conv1(x)))
     x = self.pool(F.leaky_relu(self.conv2(x)))
     # Flatten each sample while keeping dim 0 as the batch.
     # The old ``view(-1, 264 * 98 * 98)`` hard-coded the feature size. If the
     # spatial size differed, it silently regrouped values across samples and
     # changed the batch size instead of failing. It also required a
     # contiguous tensor. When each sample has 264 * 98 * 98 values, the result
     # is identical; any mismatch now raises in fc1.
     x = x.flatten(1)
     x = F.leaky_relu(self.fc1(x))
     # NOTE(review): leaky ReLU on the final layer scales negative outputs by
     # 0.01. If "species" feeds a cross-entropy loss, the raw fc2 output is
     # the usual choice. Kept here to preserve current outputs.
     x = F.leaky_relu(self.fc2(x))
     return DeviceDict({"species": x})
Ejemplo n.º 3
0
    def forward(self, input_dict):
        """Predict an ``"id"`` output from the image plus a supplied latent code.

        Args:
            input_dict: mapping with ``"z"`` (a latent tensor, flattened per
                sample and fed to ``self.fc1``) and ``"img"`` (fed to the
                conv stack). Both are assumed batched along dim 0 — TODO
                confirm.

        Returns:
            DeviceDict with a single ``"id"`` entry: the output of
            ``self.fc3`` passed through leaky ReLU.
        """
        # Latent branch.
        # flatten(1) replaces view(-1, 10 * 99 * 99). The result is identical
        # when each sample has 10 * 99 * 99 values. A mismatched size now
        # raises in fc1 instead of silently changing the batch size, and a
        # non-contiguous ``z`` no longer makes view() fail.
        z = input_dict["z"]
        z = self.fc1(z.flatten(1))

        # Image branch.
        x = input_dict["img"]
        x = self.pool2(F.leaky_relu(self.conv1(x)))
        x = self.pool2(F.leaky_relu(self.conv2(x)))
        x = self.pool(F.leaky_relu(self.conv3(x)))
        # Per-sample flatten (previously view(-1, 3 * 20 * 20)); same result
        # for the expected 3 * 20 * 20 features per sample.
        x = x.flatten(1)

        # Fuse the two branches per sample along the feature dimension.
        x = torch.cat((x, z), dim=1)

        x = F.leaky_relu(self.fc2(x))
        # NOTE(review): leaky ReLU on the final layer scales negative outputs
        # by 0.01. If "id" feeds a cross-entropy loss, consider the raw fc3
        # output. Kept here to preserve current outputs.
        x = F.leaky_relu(self.fc3(x))
        return DeviceDict({"id": x})
Ejemplo n.º 4
0
from dataset import Killer_Whale_Dataset, DeviceDict

# Initialize dataset
transform = transforms.Compose([transforms.ToTensor()])
path = "data/"
dataset = Killer_Whale_Dataset(path, transform=transform)

# Split into training and validation sets: the first 80% of indices are used
# for training and the remainder for validation.
# NOTE(review): the split follows dataset order, not a random permutation. If
# Killer_Whale_Dataset orders samples by class/individual, the validation set
# will not match the training distribution — confirm the ordering, or use
# torch.utils.data.random_split with a seeded generator.
validx = int(len(dataset) * 0.8)  # int() == floor() for non-negative values

train_set = torch.utils.data.Subset(dataset, list(range(0, validx)))
val_set = torch.utils.data.Subset(dataset, list(range(validx, len(dataset))))


def collate_fn_device(batch):
    """Collate a batch with ``default_collate`` and wrap it in a DeviceDict.

    This preserves the custom dictionary type when fetching a batch. A named
    function, unlike a lambda, can be pickled. DataLoader needs that when
    worker processes are enabled with the ``spawn`` start method.
    """
    return DeviceDict(torch.utils.data.dataloader.default_collate(batch))


train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=4,
                                           num_workers=0,
                                           pin_memory=False,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=collate_fn_device)
# Validation should see every sample exactly once, in a stable order.
# shuffle=False keeps the order stable. drop_last=False keeps the final partial
# batch; with drop_last=True, up to batch_size - 1 validation samples were
# silently excluded from every evaluation.
validation_loader = torch.utils.data.DataLoader(val_set,
                                                batch_size=4,
                                                num_workers=0,
                                                pin_memory=False,
                                                shuffle=False,
                                                drop_last=False,
                                                collate_fn=collate_fn_device)
Ejemplo n.º 5
0
 def forward(self, input_dict):
     """Encode the image and decode a mask from the resulting latent code.

     Args:
         input_dict: mapping with an ``"img"`` entry passed to ``self.encode``.

     Returns:
         DeviceDict with ``"mask"`` (output of ``self.decode``) and ``"z"``
         (the latent code produced by ``self.encode``).
     """
     latent_code = self.encode(input_dict["img"])
     predicted_mask = self.decode(latent_code)
     outputs = {"mask": predicted_mask, "z": latent_code}
     return DeviceDict(outputs)