Exemplo n.º 1
0
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        For every batch: run the model forward, take ``self.elbo`` of the
        target, minimise ``-elbo + 1e-5 * L2(posterior, prior, fcomb.layers)``,
        and record the loss plus every function in ``self.metric_ftns`` in
        ``self.train_metrics``. At most ``self.len_epoch`` batches are used.
        Afterwards the validation log (keys prefixed with ``val_``) is merged
        in when ``self.do_validation`` is set, and the LR scheduler is stepped
        if one is configured.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            # NOTE(review): the model receives the target with a new dim 1,
            # while self.elbo below receives the target as loaded - confirm
            # this asymmetry is intended.
            self.model(data, torch.unsqueeze(target, 1))

            elbo_loss, output = self.elbo(target)

            reg_loss = (l2_regularisation(self.model.posterior)
                        + l2_regularisation(self.model.prior)
                        + l2_regularisation(self.model.fcomb.layers))

            loss = -elbo_loss + 1e-5 * reg_loss
            loss.backward()
            self.optimizer.step()

            # Read the scalar once; it is reused for metrics and logging.
            loss_value = loss.item()

            # NOTE(review): the writer step is not advanced in this loop.
            # self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss_value)
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss_value))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

            # batch_idx is 0-based, so after index len_epoch - 1 exactly
            # len_epoch batches have been processed. Comparing with
            # ``batch_idx == self.len_epoch`` ran one extra batch whenever the
            # loader was longer than len_epoch.
            if batch_idx + 1 >= self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log
Exemplo n.º 2
0
# Train a ProbabilisticUnet on train_loader by minimising the negative ELBO
# plus a small L2 penalty, then save the model.
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10  # number of training epochs


# training
for epoch in range(epochs):
    print("Epoch {}".format(epoch))
    # Explicitly (re-)enter training mode each epoch so that any evaluation
    # added below (which would call net.eval()) does not leak into training.
    net.train()
    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.to(device)
        # Add a new dimension at index 1 of the mask before the forward pass.
        mask = torch.unsqueeze(mask.to(device), 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = (l2_regularisation(net.posterior)
                    + l2_regularisation(net.prior)
                    + l2_regularisation(net.fcomb.layers))
        loss = -elbo + 1e-5 * reg_loss
        if step % 100 == 0:
            # Log plain Python numbers rather than tensor reprs (which include
            # device and grad_fn noise).
            print("-- [step {}] reg_loss: {}, loss: {}".format(
                step, reg_loss.item(), loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # TODO: evaluate on held-out data after each epoch.


# save the trained net model
print("saving the trained net model")
save_model(net, path='model/unet_3.pt')
Exemplo n.º 3
0
 def loss(self, mask):
     """Return the regularised training objective for ``mask``.

     The objective is the negated ``self.elbo(mask)`` plus ``1e-5`` times the
     summed L2 penalties of ``self.posterior``, ``self.prior`` and
     ``self.fcomb.layers``.
     """
     # The ELBO is evaluated before the penalties, keeping the call order.
     negated_elbo = -self.elbo(mask)
     penalty = (l2_regularisation(self.posterior)
                + l2_regularisation(self.prior)
                + l2_regularisation(self.fcomb.layers))
     return negated_elbo + 1e-5 * penalty
Exemplo n.º 4
0
 }
 # Per-batch L1 distances between input and ground truth for this pass.
 train_targetLosses = []
 # Number of batches that passed both skip checks below and were recorded.
 train_count = 0
 for idx, data in enumerate(train_loader):
     # print("Epoch:", epoch, "idx:", idx)
     # Each batch is indexed by the keys "input" and "gt"; only element [0]
     # of each is used (assumed to be the single sample of the batch - confirm).
     inp = data["input"][0].cuda()
     gt = data["gt"][0].cuda()
     # Mean L1 distance between input and ground truth; logged, and used
     # below as a NaN guard before the network is run.
     targetLoss = torch.nn.L1Loss()(inp, gt)
     print("Target Loss:", targetLoss.item())
     # Extremely important to protect from initial KL collapse.
     # NOTE(review): this checks inp/gt only (a NaN in either makes the mean
     # NaN), not any network output; the batch is skipped without recording.
     if (torch.isnan(targetLoss)):
         continue
     net.forward(inp, gt, training=True)
     # In this variant, elbo() returns the reconstruction and KL terms separately.
     reconLoss, klLoss = net.elbo(gt)
     # KL term weighted by 10.0 (matches the beta=10.0 used elsewhere in this file).
     elbo = -(reconLoss + 10.0 * klLoss)
     l2posterior = l2_regularisation(net.posterior)
     l2prior = l2_regularisation(net.prior)
     l2fcomb = l2_regularisation(net.fcomb.layers)
     reg_loss = l2posterior + l2prior + l2fcomb
     # Equivalent to reconLoss + 10.0 * klLoss + 1e-5 * reg_loss.
     loss = -elbo + 1e-5 * reg_loss
     # Skip (do not record) batches whose total loss exceeds 1e5.
     if (loss.item() > 100000):
         continue
     print("Total Loss: ", loss.item())
     # Record every component into the train_losses dict built just above.
     # NOTE(review): no backward()/optimizer.step() appears in this loop;
     # either the snippet is truncated or this pass only records losses.
     train_losses['rec'].append(reconLoss.item())
     train_losses['kl'].append(klLoss.item())
     train_losses['l2pos'].append(l2posterior.item())
     train_losses['l2pri'].append(l2prior.item())
     train_losses['l2fcom'].append(l2fcomb.item())
     train_losses['total'].append(loss.item())
     train_targetLosses.append(targetLoss.item())
     train_count += 1
Exemplo n.º 5
0
def _prob_unet_batch_loss(model, x, y):
    """Run one ProbabilisticUnet step and return ``(elbo, reg_loss, loss)``.

    Forward pass with ``training=True``, then the ELBO of ``y`` and a
    1e-5-weighted L2 penalty on the posterior, prior and fcomb layers.
    """
    model.forward(x, y, training=True)
    elbo = model.elbo(y)
    reg_loss = (l2_regularisation(model.posterior)
                + l2_regularisation(model.prior)
                + l2_regularisation(model.fcomb.layers))
    return elbo, reg_loss, -elbo + 1e-5 * reg_loss


def _prob_unet_to_batch(array):
    """Turn one numpy sample into a tensor with a leading batch axis on ``device``."""
    # .to(device) keeps inputs on the same device as the model; a bare
    # .cuda() targets the current CUDA device regardless of ``device``.
    return torch.from_numpy(array).unsqueeze(0).to(device)


def train(args):
    """Train a ProbabilisticUnet on one task directory and save its weights.

    Each epoch runs training steps until ``trainset.iteration`` reaches
    ``args.iteration``, then validation steps under ``torch.no_grad()`` until
    ``validset.iteration`` reaches ``args.test_iteration``. It then prints one
    summary line:

    - elbo / regloss / loss come from the last *training* batch;
    - valid loss is the mean over the validation pass.

    A value that was never computed (zero iterations) is printed as ``nan``.
    The final ``state_dict`` is written to
    ``'./save/' + trainset.task_dir + 'model.pth'``; the directory is created
    if missing.

    :param args: namespace providing ``epoch``, ``learning_rate``, ``task``,
        ``iteration`` and ``test_iteration``.
    """
    import os

    def _scalar(tensor):
        # NaN stands in for a value that was never computed.
        return float('nan') if tensor is None else tensor.item()

    trainset = MedicalDataset(task_dir=args.task, mode='train')
    validset = MedicalDataset(task_dir=args.task, mode='valid')

    model = ProbabilisticUnet(input_channels=1, num_classes=1,
                              num_filters=[32, 64, 128, 192], latent_dim=2,
                              no_convs_fcomb=4, beta=10.0)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=0)

    for epoch in range(args.epoch):
        # Training pass; keep the last batch's values for the summary line.
        model.train()
        elbo = reg_loss = loss = None
        while trainset.iteration < args.iteration:
            x, y = trainset.next()
            elbo, reg_loss, loss = _prob_unet_batch_loss(
                model, _prob_unet_to_batch(x), _prob_unet_to_batch(y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        trainset.iteration = 0

        # Validation pass; uses its own names so the training values above
        # are not overwritten before they are reported.
        model.eval()
        valid_total, valid_batches = 0.0, 0
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = validset.next()
                _, _, batch_loss = _prob_unet_batch_loss(
                    model, _prob_unet_to_batch(x), _prob_unet_to_batch(y))
                valid_total += batch_loss.item()
                valid_batches += 1
            validset.iteration = 0
        valid_loss = valid_total / valid_batches if valid_batches else float('nan')

        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(
            epoch + 1, _scalar(elbo), _scalar(reg_loss), _scalar(loss), valid_loss))

    save_path = './save/' + trainset.task_dir + 'model.pth'
    # Create the target directory so torch.save does not fail on a fresh run.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
Exemplo n.º 6
0
        # Unpack one batch into images, gts, depths, grays and index_batch.
        images, gts, depths, grays, index_batch = pack
        # print(index_batch)
        # NOTE(review): torch.autograd.Variable is a no-op wrapper since
        # PyTorch 0.4 (Variable and Tensor were merged); kept as written.
        images = Variable(images)
        gts = Variable(gts)
        depths = Variable(depths)
        grays = Variable(grays)
        images = images.cuda()
        gts = gts.cuda()
        depths = depths.cuda()
        grays = grays.cuda()

        # The generator returns five outputs; the _post/_prior names suggest
        # predictions from the posterior and prior branches (confirm).
        pred_post, pred_prior, latent_loss, depth_pred_post, depth_pred_prior = generator.forward(
            images, depths, gts)

        ## L2 regularisation of the three encoders (xy_encoder, x_encoder,
        ## sal_encoder), scaled by opt.reg_weight below. NOTE(review): within
        ## these lines reg_loss is not added to gen_loss_cvae; presumably it
        ## is used further down.
        reg_loss = l2_regularisation(generator.xy_encoder) + \
                l2_regularisation(generator.x_encoder) + l2_regularisation(generator.sal_encoder)
        # Smoothness term on the sigmoid of the posterior prediction, weighted by opt.sm_weight.
        smoothLoss_post = opt.sm_weight * smooth_loss(torch.sigmoid(pred_post),
                                                      gts)
        reg_loss = opt.reg_weight * reg_loss
        # NOTE(review): no-op self-assignment.
        latent_loss = latent_loss
        # MSE between the sigmoid of the posterior depth prediction and depths,
        # weighted by opt.depth_loss_weight.
        depth_loss_post = opt.depth_loss_weight * mse_loss(
            torch.sigmoid(depth_pred_post), depths)
        # Posterior-branch loss: structure loss + smoothness + depth terms.
        sal_loss = structure_loss(pred_post,
                                  gts) + smoothLoss_post + depth_loss_post
        # Annealing factor for the latent loss; presumably ramps 0 -> 1 as
        # epoch approaches opt.epoch (confirm against linear_annealing).
        anneal_reg = linear_annealing(0, 1, epoch, opt.epoch)
        latent_loss = opt.lat_weight * anneal_reg * latent_loss
        # CVAE loss = (posterior loss + annealed latent loss) * opt.vae_loss_weight.
        gen_loss_cvae = sal_loss + latent_loss
        gen_loss_cvae = opt.vae_loss_weight * gen_loss_cvae

        smoothLoss_prior = opt.sm_weight * smooth_loss(