Code example #1 (score: 0)
    def _reconstruction_errors(images):
        """Compute reconstruction errors for anomaly scoring.

        Encodes the input images, reconstructs them through the generator,
        and measures two L1 distances: raw pixel space and discriminator
        feature space.

        Returns:
            (pixel_distance, features_distance) — the L_R and L_f_D terms.
        """
        encoded = encoder(images, training=False)    # E(x)
        decoded = generator(encoded, training=False)  # G(E(x))

        real_feats = discriminator_features([images, encoded], training=False)   # f_D(x, E(x))
        fake_feats = discriminator_features([decoded, encoded], training=False)  # f_D(G(E(x)), E(x))

        pixel_distance = l1(images, decoded)           # L_R
        features_distance = l1(real_feats, fake_feats)  # L_f_D

        return pixel_distance, features_distance
Code example #2 (score: 0)
    def validate(self, epoch, output_save):
        """Run one validation epoch and return the mean L1 loss and mean IoU.

        Args:
            epoch: zero-based epoch index (used only to name the vis folder).
            output_save: root directory for per-epoch prediction dumps.

        Returns:
            (val_loss, mean_iou) averaged over the validation batches.
        """
        self.model.eval()
        batch_loss = 0.0
        batch_iou = 0.0
        # Folder for this epoch's prediction dumps (consumed by the
        # commented-out save block below).
        vis_save = os.path.join(output_save, "epoch%02d" % (epoch+1))

        n_batches = len(self.dataloader_val)
        with torch.no_grad():
            for idx, sample in enumerate(self.dataloader_val):
                input = sample['occ_grid'].to(self.device)
                target_df = sample['df_gt'].to(self.device)
                names = sample['name']

                # ===================forward=====================
                output_df = self.model(input)
                loss = losses.l1(output_df, target_df, use_log_transform=False)
                iou = metric.iou_df(output_df, target_df, trunc_dist=1.0)

                # ===================log========================
                batch_loss += loss.item()
                batch_iou += iou

                # save the predictions at the end of the epoch
                # if epoch > args.save_epoch and (idx + 1) == n_batches-1:
                #     pred_dfs = output_df[:args.n_vis + 1]
                #     target_dfs = target_df[:args.n_vis + 1]
                #     names = names[:args.n_vis + 1]
                #     utils.save_predictions(vis_save, args.model_name, args.gt_type, names, pred_dfs=pred_dfs, target_dfs=target_dfs,
                #                            pred_occs=None, target_occs=None)

            # Average over n_batches (== idx + 1 for a non-empty loader);
            # using `idx` here would raise NameError when the loader is empty.
            val_loss = batch_loss / max(n_batches, 1)
            mean_iou = batch_iou / max(n_batches, 1)
            return val_loss, mean_iou
Code example #3 (score: 0)
    def train(self, epoch):
        """Run one training epoch and return the mean L1 loss.

        Args:
            epoch: zero-based epoch index (used only by the commented-out
                progress logging below).

        Returns:
            Mean L1 loss over the epoch's training batches.
        """
        self.model.train()
        batch_loss = 0.0
        n_batches = len(self.dataloader_train)
        for idx, sample in enumerate(self.dataloader_train):
            input = sample['occ_grid'].to(self.device)
            target = sample['df_gt'].to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # ===================forward=====================
            output = self.model(input)
            loss = losses.l1(output, target, use_log_transform=args.use_logweight)
            # ===================backward + optimize====================
            loss.backward()
            self.optimizer.step()

            # ===================log========================
            batch_loss += loss.item()

            # if (idx + 1) % 10 == 0:
            #     print('Training : [iter %d / epoch %d] loss: %.3f' % (idx + 1, epoch + 1, loss.item()))

        # Average over n_batches (== idx + 1 for a non-empty loader);
        # using `idx` here would raise NameError when the loader is empty.
        train_loss = batch_loss / max(n_batches, 1)
        return train_loss
Code example #4 (score: 0)
File: trainer_occ.py — Project: ParikaGoel/generate3D
    def validate(self, epoch, output_save):
        """Run one validation epoch for the occupancy model.

        Computes BCE loss on occupancy, converts predicted occupancies to
        distance fields for an L1 comparison, and tracks occupancy IoU.

        Args:
            epoch: zero-based epoch index (used only to name the vis folder).
            output_save: root directory for per-epoch prediction dumps.

        Returns:
            (val_loss_bce, val_loss_l1, mean_iou) averaged over batches.
        """
        self.model.eval()
        batch_loss_bce = 0.0
        batch_loss_l1 = 0.0
        batch_iou = 0.0
        # Folder for this epoch's prediction dumps (consumed by the
        # commented-out save block below).
        vis_save = os.path.join(output_save, "epoch%02d" % (epoch + 1))

        n_batches = len(self.dataloader_val)
        with torch.no_grad():
            for idx, sample in enumerate(self.dataloader_val):
                input = sample['occ_grid'].to(self.device)
                target_occ = sample['occ_gt'].to(self.device)
                target_df = sample['occ_df_gt'].to(self.device)
                names = sample['name']

                # ===================forward=====================
                output_occ = self.model(input)
                loss_bce = losses.bce(output_occ, target_occ)

                # Convert occ to df to calculate l1 loss
                output_df = utils.occs_to_dfs(output_occ,
                                              trunc=args.truncation,
                                              pred=True)
                loss_l1 = losses.l1(output_df, target_df)
                iou = metric.iou_occ(output_occ, target_occ)

                # ===================log========================
                batch_loss_bce += loss_bce.item()
                batch_loss_l1 += loss_l1.item()
                batch_iou += iou

                # save the predictions at the end of the epoch
                # if epoch > args.save_epoch and (idx + 1) == n_batches-1:
                #     pred_occs = output_occ[:args.n_vis+1]
                #     target_occs = target_occ[:args.n_vis+1]
                #     names = names[:args.n_vis+1]
                #     utils.save_predictions(vis_save, args.model_name, args.gt_type, names, pred_dfs=None, target_dfs=None,
                #                            pred_occs=pred_occs, target_occs=target_occs)

            # Average over n_batches (== idx + 1 for a non-empty loader);
            # using `idx` here would raise NameError when the loader is empty.
            val_loss_bce = batch_loss_bce / max(n_batches, 1)
            val_loss_l1 = batch_loss_l1 / max(n_batches, 1)
            mean_iou = batch_iou / max(n_batches, 1)
            return val_loss_bce, val_loss_l1, mean_iou
Code example #5 (score: 0)
def test(test_list):
    """Evaluate a saved checkpoint on the test split and print mean metrics.

    Builds the network named by ``args.model_name``, restores weights from
    ``args.model_path``, runs inference without gradients, accumulates L1
    error and IoU, and saves predictions from the last batches for visual
    inspection.

    Args:
        test_list: list of sample identifiers passed to the dataset loader.

    Raises:
        ValueError: if ``args.model_name`` is not a known architecture.
    """
    dataset_test = dataloader.DatasetLoad(data_list=test_list,
                                          truncation=args.truncation)
    if args.model_name == 'Net3D':
        model = Net3D(1, 1).to(device)
    elif args.model_name == 'UNet3D':
        model = UNet3D(1, 1).to(device)
    else:
        # Fail fast: otherwise an unknown name surfaces later as a
        # confusing NameError on `model`.
        raise ValueError("unsupported model name: %s" % args.model_name)

    dataloader_test = torchdata.DataLoader(dataset_test,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=2,
                                           drop_last=False)

    # load our saved model and use it to predict the class for test images
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()  # already on `device` from construction above

    vis_save = "%sfinal_results/vis/%s/%s" % (params["network_output"],
                                              args.model_name, args.gt_type)
    pathlib.Path(vis_save).mkdir(parents=True, exist_ok=True)

    batch_l1 = 0.0
    batch_iou = 0.0
    n_batches = len(dataloader_test)
    with torch.no_grad():
        for idx, sample in enumerate(dataloader_test):
            input = sample['occ_grid'].to(device)
            names = sample['name']

            if args.gt_type == 'occ':
                target_occ = sample['occ_gt'].to(device)
                target_df = sample['occ_df_gt'].to(device)

                # ===================forward=====================
                output_occ = model(input)

                # Convert occ to df to calculate l1 loss
                output_df = utils.occs_to_dfs(output_occ,
                                              trunc=args.truncation,
                                              pred=True)
                l1 = losses.l1(output_df, target_df)
                iou = metric.iou_occ(output_occ, target_occ)

                # save the predictions from the last two batches
                if (idx + 1) > n_batches - 2:
                    pred_occs = output_occ[:args.n_vis + 1]
                    target_occs = target_occ[:args.n_vis + 1]
                    names = names[:args.n_vis + 1]
                    utils.save_predictions(vis_save,
                                           args.model_name,
                                           args.gt_type,
                                           names,
                                           pred_dfs=None,
                                           target_dfs=None,
                                           pred_occs=pred_occs,
                                           target_occs=target_occs)
            else:
                target_df = sample['df_gt'].to(device)

                output_df = model(input)
                l1 = losses.l1(output_df, target_df)
                iou = metric.iou_df(output_df, target_df, trunc_dist=1.0)

                # save the predictions from the last two batches
                if (idx + 1) > n_batches - 2:
                    pred_dfs = output_df[:args.n_vis + 1]
                    target_dfs = target_df[:args.n_vis + 1]
                    names = names[:args.n_vis + 1]
                    utils.save_predictions(vis_save,
                                           args.model_name,
                                           args.gt_type,
                                           names,
                                           pred_dfs=pred_dfs,
                                           target_dfs=target_dfs,
                                           pred_occs=None,
                                           target_occs=None)

            batch_l1 += l1.item()
            batch_iou += iou

        # Average over n_batches (== idx + 1 for a non-empty loader);
        # using `idx` here would raise NameError when the loader is empty.
        l1_error = batch_l1 / max(n_batches, 1)
        mean_iou = batch_iou / max(n_batches, 1)

        print("Mean IOU: ", mean_iou)
        print("L1 Error: ", l1_error)