def run(self): trainset = COCOStuff(Path(self.config["dataset path"], "train"), is_cropped=self.config["is cropped"], crop_size=(self.config["img width"], self.config["img height"]), in_memory=self.config["in memory"]) train_loader = DataLoader(dataset=trainset, batch_size=self.config["batch size"], shuffle=True) valset = COCOStuff(Path(self.config["dataset path"], "val"), is_cropped=self.config["is cropped"], crop_size=(self.config["img width"], self.config["img height"]), in_memory=self.config["in memory"]) val_loader = DataLoader(dataset=valset, batch_size=self.config["batch size"]) net_module = importlib.import_module( ("lib.models.{}".format(self.config["model"]))) net = getattr(net_module, "build_" + self.config["model"]) model = net(n_classes=self.N_CLASSES).to(self.device) start_epochs = self._load_checkpoint(model) optimizer = torch.optim.Adam(model.parameters(), lr=self.config["learning rate"]) for epoch in tqdm(range(start_epochs, self.config["epochs"])): model.train() total_loss = 0 for X, Y in tqdm(train_loader): X, Y = X.to(self.device), Y.long().to(self.device) Y_ = model(X) loss = cross_entropy2d(Y_, Y) total_loss += loss.item() loss.backward() optimizer.step() optimizer.zero_grad() avg_loss = total_loss / len(train_loader) self.logger.log("epoch", epoch, "loss", avg_loss) model.eval() ious = [] with torch.no_grad(): for X, Y in val_loader: X, Y = X.to(self.device), Y.long().to(self.device) Y_ = model(X) _, predicted = torch.max(Y_.data, 1) iou = get_iou(predicted, Y) ious.append(iou.item()) mean_iou = mean(ious) self.logger.log("epoch", epoch, "iou", mean_iou) self.logger.graph() self._save_checkpoint(epoch, model, retain=True)
def run(self): # Training dataset trainset = COCOSingleStuffF1( Path(self.config["dataset path"], "train"), threshold=self.config["threshold"]) train_loader = DataLoader( dataset=trainset, batch_size=self.config["batch size"], shuffle=True) # Validation dataset valset = COCOSingleStuffF1( Path(self.config["dataset path"], "val"), threshold=self.config["threshold"]) val_loader = DataLoader( dataset=valset, batch_size=self.config["batch size"]) # Model model = SegNetF1(n_classes=self.N_CLASSES).to(self.device) start_epochs = self._load_checkpoint(model) # Constants num_positives = train_loader.dataset.num_positives # Primal variables tau = torch.rand( len(train_loader.dataset), device=self.device, requires_grad=True) eps = torch.rand(1, device=self.device, requires_grad=True) w = torch.rand(1, device=self.device, requires_grad=True) # Dual variables lamb = torch.zeros(len(train_loader.dataset), device=self.device) lamb.fill_(0.001) mu = torch.zeros(1, device=self.device) mu.fill_(0.001) gamma = torch.zeros(len(train_loader.dataset), device=self.device) gamma.fill_(0.001) # Primal Optimization var_list = [{ "params": model.parameters(), "lr": self.config["learning rate"] }, { "params": tau, "lr": self.config["eta_tau"] }, { "params": eps, "lr": self.config["eta_eps"] }, { "params": w, "lr": self.config["eta_w"] }] optimizer = torch.optim.SGD(var_list) # Dataset iterator train_iter = iter(train_loader) # Count epochs and steps epochs = 0 step = 0 # Cache losses total_loss = 0 total_t1_loss = 0 total_t2_loss = 0 for outer in tqdm(range(start_epochs, self.config["n_outer"])): model.train() for inner in tqdm(range(self.config["n_inner"])): step += 1 # Sample try: X, y0, y1, i = next(train_iter) except StopIteration: train_iter = iter(train_loader) X, y0, y1, i = next(train_iter) # Forward computation X, y0 = X.to(self.device), y0.long().to(self.device) y1, i = y1.to(self.device), i.to(self.device) y0_, y1_ = model(X) # Compute loss t1_loss = cross_entropy2d(y0_, y0) t2_loss = lagrange(num_positives, y1_, y1, w, eps, tau[i], lamb[i], mu, gamma[i]) loss = t1_loss + (self.config["beta"] * t2_loss) # Store losses for logging total_loss += loss.item() total_t1_loss += t1_loss.item() total_t2_loss += t2_loss.item() # Backpropagate loss.backward() optimizer.step() optimizer.zero_grad() # Project eps to ensure non-negativity eps.data = torch.max( torch.zeros(1, dtype=torch.float, device=self.device), eps.data) # Log and validate per epoch if (step + 1) % len(train_loader) == 0: epochs += 1 # Log loss avg_loss = total_loss / len(train_loader) avg_t1_loss = total_t1_loss / len(train_loader) avg_t2_loss = total_t2_loss / len(train_loader) total_loss = 0 total_t1_loss = 0 total_t2_loss = 0 self.logger.log("epochs", epochs, "loss", avg_loss) self.logger.log("epochs", epochs, "t1loss", avg_t1_loss) self.logger.log("epochs", epochs, "t2loss", avg_t2_loss) # Validate model.eval() ious = [] with torch.no_grad(): for X, y0, _, _ in val_loader: X, y0 = X.to(self.device), y0.long().to( self.device) y0_, _ = model(X) _, predicted = torch.max(y0_.data, 1) iou = get_iou(predicted, y0) ious.append(iou.item()) # Log mean IOU mean_iou = mean(ious) self.logger.log("epochs", epochs, "mean_iou", mean_iou) # Graph self.logger.graph() # Checkpoint self._save_checkpoint(epochs, model) # Dual Updates with torch.no_grad(): mu_cache = 0 lamb_cache = torch.zeros_like(lamb) gamma_cache = torch.zeros_like(gamma) for X, y0, y1, i in tqdm(train_loader): # Forward computation X, y0 = X.to(self.device), y0.long().to(self.device) y1, i = y1.to(self.device), i.to(self.device) y0_, y1_ = model(X) # Cache for mu update mu_cache += tau[i].sum() # Lambda and gamma updates y1 = y1.float() y1_ = y1_.view(-1) lamb_cache[i] += ( self.config["eta_lamb"] * (y1 * (tau[i] - (w * y1_)))) gamma_cache[i] += ( self.config["eta_gamma"] * (y1 * (tau[i] - eps))) # Update data mu.data += self.config["eta_mu"] * (mu_cache - 1) lamb.data += lamb_cache gamma.data += gamma_cache