Example #1
def test(epoch, iter, log_file_semantic_val, log_file_scan_val, val_file,
         log_file_2d_val):
    test_loss_semantic = []  # To store semantic loss at each iteration
    test_loss_scan = []  # To store scan loss at each iteration
    test_loss_2d = []
    model.eval()
    model2d_fixed.eval()
    model2d_trainable.eval()
    if opt.use_proxy_loss:
        model2d_classifier.eval()
    start = time.time()

    # The HDF5 file is too large to load at once (10000 samples); load it in 10 chunks of ~1000 samples each.
    print('Validating on %s' % val_file)
    for h5py_index in range(10):
        volumes, labels, frames, world_to_grids = data_util.load_hdf5_data(
            val_file, num_classes, h5py_index)
        frames = frames[:, :2 + num_images]
        volumes = volumes.permute(0, 1, 4, 3, 2)
        labels = labels.permute(0, 1, 4, 3, 2)
        labels = labels[:, 0, :, grid_centerX,
                        grid_centerY]  # center columns as targets

        # Filter out samples whose 2D frames are not available
        available_frames_index = data_util.get_available_frames_id(
            opt.data_path_2d, frames)
        if len(available_frames_index) < batch_size:
            continue
        volumes = volumes[available_frames_index]
        labels = labels[available_frames_index]
        frames = frames[available_frames_index]
        world_to_grids = world_to_grids[available_frames_index]

        num_samples = volumes.shape[0]
        # shuffle
        indices = torch.randperm(num_samples).long().split(batch_size)
        # remove last mini-batch so that all the batches have equal size
        indices = indices[:-1]

        with torch.no_grad():
            if CUDA_AVAILABLE:
                mask_semantic = torch.cuda.LongTensor(batch_size *
                                                      column_height)
                depth_images = torch.cuda.FloatTensor(batch_size * num_images,
                                                      proj_image_dims[1],
                                                      proj_image_dims[0])
                color_images = torch.cuda.FloatTensor(batch_size * num_images,
                                                      3, input_image_dims[1],
                                                      input_image_dims[0])
                camera_poses = torch.cuda.FloatTensor(batch_size * num_images,
                                                      4, 4)
                label_images = torch.cuda.LongTensor(batch_size * num_images,
                                                     proj_image_dims[1],
                                                     proj_image_dims[0])
            else:
                mask_semantic = torch.LongTensor(batch_size * column_height)
                depth_images = torch.FloatTensor(batch_size * num_images,
                                                 proj_image_dims[1],
                                                 proj_image_dims[0])
                color_images = torch.FloatTensor(batch_size * num_images, 3,
                                                 input_image_dims[1],
                                                 input_image_dims[0])
                camera_poses = torch.FloatTensor(batch_size * num_images, 4, 4)
                label_images = torch.LongTensor(batch_size * num_images,
                                                proj_image_dims[1],
                                                proj_image_dims[0])
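            # These buffers are reused across mini-batches; data_util.load_frames_multi and
            # load_label_frames fill them in-place for every batch below.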

            for t, v in enumerate(indices):
                # print(t, v)
                if CUDA_AVAILABLE:
                    targets_semantic = labels[v].cuda()
                else:
                    targets_semantic = labels[v]

                # Ignore Invalid targets for semantic
                mask_semantic = targets_semantic.view(-1).data.clone()
                for k in range(num_classes):
                    if criterion_weights_semantic[k] == 0:
                        mask_semantic[mask_semantic.eq(k)] = 0
                mask_indices_semantic = mask_semantic.nonzero().squeeze()
                if len(mask_indices_semantic.shape) == 0:
                    continue

                # Ignore Invalid targets for scan
                # Create a mask where 1 marks voxels that are known-free or known-occupied.
                # An input of 0 should map to target 0, 1 to 1, and 2 (already unknown before voxel discarding) to 2.
                if opt.train_scan_completion:
                    mask_scan = targets_semantic.view(-1).data.clone()
                    mask_scan[:] = 1
                    # Ignore Unknown Voxels from before.
                    mask_scan[targets_semantic.view(-1).eq(opt.num_classes -
                                                           1)] = 0
                    mask_scan_indices = mask_scan.nonzero().squeeze()
                    if len(mask_scan_indices.shape) == 0:
                        continue

                    # ToDo: What if you generate targets_scan from volumetric grid?
                    # ToDo: You should get the same result but confirm.
                    targets_scan = targets_semantic.view(-1).data.clone()
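                    # The product of the two comparison masks below acts as an element-wise
                    # logical AND, selecting labels in the range [1, num_classes - 2].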
                    targets_scan[torch.ge(targets_scan, 1) *
                                 torch.lt(targets_scan, num_classes - 1)] = 1
                    targets_scan[torch.eq(targets_scan, num_classes -
                                          1)] = 2  # Label 41 with class 2

                transforms = world_to_grids[v].unsqueeze(1)
                transforms = transforms.expand(batch_size, num_images, 4,
                                               4).contiguous().view(-1, 4, 4)
                if CUDA_AVAILABLE:
                    transforms = transforms.cuda()

                # get 2d data
                is_load_success = data_util.load_frames_multi(
                    opt.data_path_2d, frames[v], depth_images, color_images,
                    camera_poses, color_mean, color_std)
                if not is_load_success:
                    continue

                # 3d Input
                volume = volumes[v]
                # Get indices of voxels to be removed if training scan completion
                random_center_voxel_indices = torch.Tensor()  # Empty Tensor
                if opt.train_scan_completion:
                    # ToDo: For all samples in each batch, the same random voxels are removed.
                    # ToDo: Voxels that are already unknown also get removed.
                    random_center_voxel_indices = projection.get_random_center_voxels_index(
                        opt.voxel_removal_fraction)
                    # Mark the selected center voxels of the 3D volume as Unknown (0)
                    volume[:, :, random_center_voxel_indices,
                           projection.volume_dims[0] // 2,
                           projection.volume_dims[1] // 2] = 0

                # Compute projection mapping and mark center voxels as Unknown if training for scan completion
                proj_mapping = [
                    projection.compute_projection(d, c, t,
                                                  random_center_voxel_indices)
                    for d, c, t in zip(depth_images, camera_poses, transforms)
                ]
                if None in proj_mapping:
                    print('No mapping in proj_mapping')
                    continue
                proj_mapping = list(zip(*proj_mapping))
                proj_ind_3d = torch.stack(proj_mapping[0])
                proj_ind_2d = torch.stack(proj_mapping[1])
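                # proj_ind_3d / proj_ind_2d hold, per depth image, the linearized indices of
                # corresponding 3D voxels and 2D pixels, presumably consumed by the projection
                # layer inside `model` to pool 2D features into the 3D grid.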

                if opt.use_proxy_loss:
                    data_util.load_label_frames(opt.data_path_2d, frames[v],
                                                label_images, num_classes)
                    mask2d = label_images.view(-1).clone()
                    for k in range(num_classes):
                        if criterion_weights_semantic[k] == 0:
                            mask2d[mask2d.eq(k)] = 0
                    mask2d = mask2d.nonzero().squeeze()
                    if len(mask2d.shape) == 0:
                        continue  # nothing to optimize for here

                # 2d
                imageft_fixed = model2d_fixed(color_images)
                imageft = model2d_trainable(imageft_fixed)
                if opt.use_proxy_loss:
                    ft2d = model2d_classifier(imageft)
                    ft2d = ft2d.permute(0, 2, 3, 1).contiguous()

                # 2d/3d
                if CUDA_AVAILABLE:
                    input3d = volume.cuda()
                else:
                    input3d = volume

                # Forward Pass Only
                output_semantic, output_scan = model(input3d, imageft,
                                                     proj_ind_3d, proj_ind_2d,
                                                     grid_dims)

                # Compute Scan and semantic Loss
                loss_semantic = criterion_semantic(
                    output_semantic.view(-1, num_classes),
                    targets_semantic.view(-1))
                test_loss_semantic.append(loss_semantic.item())
                if opt.train_scan_completion:
                    loss_scan = criterion_scan(
                        output_scan.view(-1, _NUM_OCCUPANCY_STATES),
                        targets_scan.view(-1))
                    test_loss_scan.append(loss_scan.item())

                if opt.use_proxy_loss:
                    loss2d = criterion2d(ft2d.view(-1, num_classes),
                                         label_images.view(-1))
                    test_loss_2d.append(loss2d.item())
                    # Confusion
                    y = ft2d.data
                    y = y.view(-1, num_classes)[:, :-1]
                    _, predictions = y.max(1)
                    predictions = predictions.view(-1)
                    k = label_images.view(-1)
                    confusion2d_val.add(
                        torch.index_select(predictions, 0, mask2d),
                        torch.index_select(k, 0, mask2d))

                # Confusion for Semantic
                y = output_semantic.data
                y = y.view(y.nelement() // y.size(2), num_classes)[:, :-1]
                _, predictions = y.max(1)
                predictions = predictions.view(-1)
                k = targets_semantic.data.view(-1)
                confusion_val.add(
                    torch.index_select(predictions, 0, mask_indices_semantic),
                    torch.index_select(k, 0, mask_indices_semantic))

                # Confusion for Scan completion
                if opt.train_scan_completion:
                    y = output_scan.data
                    # Discard the prediction for the Unknown occupancy state in output_scan
                    y = y.view(y.nelement() // y.size(2),
                               _NUM_OCCUPANCY_STATES)[:, :-1]
                    _, predictions_scan = y.max(1)
                    predictions_scan = predictions_scan.view(-1)
                    k = targets_scan.data.view(-1)
                    confusion_scan_val.add(
                        torch.index_select(predictions_scan, 0,
                                           mask_scan_indices),
                        torch.index_select(k, 0, mask_scan_indices))

    end = time.time()
    took = end - start
    evaluate_confusion(confusion_val, test_loss_semantic, epoch, iter, took,
                       'ValidationSemantic', log_file_semantic_val,
                       num_classes)
    if opt.train_scan_completion:
        evaluate_confusion(confusion_scan_val, test_loss_scan, epoch, iter,
                           took, 'ValidationScan', log_file_scan_val,
                           _NUM_OCCUPANCY_STATES)
    if opt.use_proxy_loss:
        evaluate_confusion(confusion2d_val, test_loss_2d, epoch, iter, took,
                           'Validation2d', log_file_2d_val, num_classes)
    return test_loss_semantic, test_loss_scan, test_loss_2d
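
The confusion meters used above (confusion_val, confusion_scan_val, confusion2d_val) are module-level globals that are not part of this listing. A minimal sketch of how they could be constructed, assuming a torchnet-style ConfusionMeter whose add(predictions, targets) call matches the usage above:

import torchnet as tnt

confusion_val = tnt.meter.ConfusionMeter(num_classes)
confusion_scan_val = tnt.meter.ConfusionMeter(_NUM_OCCUPANCY_STATES)
confusion2d_val = tnt.meter.ConfusionMeter(num_classes)
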
Example #2
def test(epoch, iter, log_file, val_file, log_file_2d):
    test_loss = []
    test_loss_2d = []
    model.eval()
    model2d_fixed.eval()
    model2d_trainable.eval()
    if opt.use_proxy_loss:
        model2d_classifier.eval()
    start = time.time()

    volumes, labels, frames, world_to_grids = data_util.load_hdf5_data(val_file, num_classes)
    frames = frames[:, :2+num_images]
    volumes = volumes.permute(0, 1, 4, 3, 2)
    labels = labels.permute(0, 1, 4, 3, 2)
    labels = labels[:, 0, :, grid_centerX, grid_centerY]  # center columns as targets
    num_samples = volumes.shape[0]
    # shuffle
    indices = torch.randperm(num_samples).long().split(batch_size)
    # remove last mini-batch so that all the batches have equal size
    indices = indices[:-1]

    with torch.no_grad():
        mask = torch.cuda.LongTensor(batch_size*column_height)
        depth_images = torch.cuda.FloatTensor(batch_size * num_images, proj_image_dims[1], proj_image_dims[0])
        color_images = torch.cuda.FloatTensor(batch_size * num_images, 3, input_image_dims[1], input_image_dims[0])
        camera_poses = torch.cuda.FloatTensor(batch_size * num_images, 4, 4)
        label_images = torch.cuda.LongTensor(batch_size * num_images, proj_image_dims[1], proj_image_dims[0])

        for t, v in enumerate(indices):
            targets = labels[v].cuda()
            # valid targets
            mask = targets.view(-1).data.clone()
            for k in range(num_classes):
                if criterion_weights[k] == 0:
                    mask[mask.eq(k)] = 0
            maskindices = mask.nonzero().squeeze()
            if len(maskindices.shape) == 0:
                continue

            transforms = world_to_grids[v].unsqueeze(1)
            transforms = transforms.expand(batch_size, num_images, 4, 4).contiguous().view(-1, 4, 4).cuda()
            # get 2d data
            data_util.load_frames_multi(opt.data_path_2d, frames[v], depth_images, color_images, camera_poses, color_mean, color_std)
            if opt.use_proxy_loss:
                data_util.load_label_frames(opt.data_path_2d, frames[v], label_images, num_classes)
                mask2d = label_images.view(-1).clone()
                for k in range(num_classes):
                    if criterion_weights[k] == 0:
                        mask2d[mask2d.eq(k)] = 0
                mask2d = mask2d.nonzero().squeeze()
                if len(mask2d.shape) == 0:
                    continue  # nothing to optimize for here
            # compute projection mapping
            proj_mapping = [projection.compute_projection(d, c, t) for d, c, t in zip(depth_images, camera_poses, transforms)]
            if None in proj_mapping:  # invalid sample
                # print('(invalid sample)')
                continue
            proj_mapping = list(zip(*proj_mapping))  # list() so the result can be indexed below
            proj_ind_3d = torch.stack(proj_mapping[0])
            proj_ind_2d = torch.stack(proj_mapping[1])
            # 2d
            imageft_fixed = model2d_fixed(color_images)
            imageft = model2d_trainable(imageft_fixed)
            if opt.use_proxy_loss:
                ft2d = model2d_classifier(imageft)
                ft2d = ft2d.permute(0, 2, 3, 1).contiguous()
            # 2d/3d
            input3d = volumes[v].cuda()
            output = model(input3d, imageft, proj_ind_3d, proj_ind_2d, grid_dims)
            loss = criterion(output.view(-1, num_classes), targets.view(-1))
            test_loss.append(loss.item())
            if opt.use_proxy_loss:
                loss2d = criterion2d(ft2d.view(-1, num_classes), label_images.view(-1))
                test_loss_2d.append(loss2d.item())
                # confusion
                y = ft2d.data
                y = y.view(-1, num_classes)[:, :-1]
                _, predictions = y.max(1)
                predictions = predictions.view(-1)
                k = label_images.view(-1)
                confusion2d_val.add(torch.index_select(predictions, 0, mask2d), torch.index_select(k, 0, mask2d))
            
            # confusion
            y = output.data
            y = y.view(y.nelement() // y.size(2), num_classes)[:, :-1]
            _, predictions = y.max(1)
            predictions = predictions.view(-1)
            k = targets.data.view(-1)
            confusion_val.add(torch.index_select(predictions, 0, maskindices), torch.index_select(k, 0, maskindices))

    end = time.time()
    took = end - start
    evaluate_confusion(confusion_val, test_loss, epoch, iter, took, 'Test', log_file)
    if opt.use_proxy_loss:
        evaluate_confusion(confusion2d_val, test_loss_2d, epoch, iter, took, 'Test2d', log_file_2d)
    return test_loss, test_loss_2d
Example #3
def train(epoch, iter, log_file_semantic, log_file_scan, train_file,
          log_file_2d):
    train_loss_semantic = []  # To store semantic loss at each iteration
    train_loss_scan = []  # To store scan loss at each iteration
    train_loss_2d = []
    model.train()
    start = time.time()
    model2d_trainable.train()
    if opt.use_proxy_loss:
        model2d_classifier.train()

    # The HDF5 file is too large to load at once (10000 samples); load it in 10 chunks of ~1000 samples each.
    print('Training on %s' % train_file)
    for h5py_index in range(10):
        volumes, labels, frames, world_to_grids = data_util.load_hdf5_data(
            train_file, num_classes, h5py_index)
        frames = frames[:, :2 + num_images]
        volumes = volumes.permute(0, 1, 4, 3, 2)
        labels = labels.permute(0, 1, 4, 3, 2)
        labels = labels[:, 0, :, grid_centerX,
                        grid_centerY]  # center columns as targets

        # Filter out samples whose 2D frames are not available
        available_frames_index = data_util.get_available_frames_id(
            opt.data_path_2d, frames)
        if len(available_frames_index) < batch_size:
            continue
        volumes = volumes[available_frames_index]
        labels = labels[available_frames_index]
        frames = frames[available_frames_index]
        world_to_grids = world_to_grids[available_frames_index]

        num_samples = volumes.shape[0]
        # shuffle
        indices = torch.randperm(num_samples).long().split(batch_size)
        # remove last mini-batch so that all the batches have equal size
        indices = indices[:-1]

        if CUDA_AVAILABLE:
            mask = torch.cuda.LongTensor(batch_size * column_height)
            depth_images = torch.cuda.FloatTensor(batch_size * num_images,
                                                  proj_image_dims[1],
                                                  proj_image_dims[0])
            color_images = torch.cuda.FloatTensor(batch_size * num_images, 3,
                                                  input_image_dims[1],
                                                  input_image_dims[0])
            camera_poses = torch.cuda.FloatTensor(batch_size * num_images, 4,
                                                  4)
            label_images = torch.cuda.LongTensor(batch_size * num_images,
                                                 proj_image_dims[1],
                                                 proj_image_dims[0])
        else:
            mask = torch.LongTensor(batch_size * column_height)
            depth_images = torch.FloatTensor(batch_size * num_images,
                                             proj_image_dims[1],
                                             proj_image_dims[0])
            color_images = torch.FloatTensor(batch_size * num_images, 3,
                                             input_image_dims[1],
                                             input_image_dims[0])
            camera_poses = torch.FloatTensor(batch_size * num_images, 4, 4)
            label_images = torch.LongTensor(batch_size * num_images,
                                            proj_image_dims[1],
                                            proj_image_dims[0])

        for t, v in enumerate(indices):
            iter_start = time.time()
            # print(t, v)
            if CUDA_AVAILABLE:
                targets_semantic = torch.autograd.Variable(labels[v].cuda())
            else:
                targets_semantic = torch.autograd.Variable(labels[v])

            # Ignore Invalid targets for semantic
            mask_semantic = targets_semantic.view(-1).data.clone()
            for k in range(num_classes):
                if criterion_weights_semantic[k] == 0:
                    mask_semantic[mask_semantic.eq(k)] = 0
            mask_semantic_indices = mask_semantic.nonzero().squeeze(
            )  # Used in confusion matrix
            if len(mask_semantic_indices.shape) == 0:
                continue

            # Ignore Invalid targets for scan
            # occ[0] = np.less_equal(np.abs(sdfs), 1)  # occupied space - 1, empty space - 0
            # occ[1] = np.greater_equal(sdfs, -1)      # known space - 1, unknown space - 0
            # (occ[1], occ[0]) = (1, 0) -> value 2 : known-free space     -> target 0
            # (occ[1], occ[0]) = (1, 1) -> value 3 : known-occupied space -> target 1
            # (occ[1], occ[0]) = (0, 0) -> value 0 : unknown space        -> target 2
            # Create a mask where 1 marks voxels that are known-free or known-occupied.
            # ToDo: Ask tutor: What if I don't use a mask?
            # An input of 0 should map to target 0, 1 to 1, and 2 (already unknown before voxel discarding) to 2.
            if opt.train_scan_completion:
                mask_scan = targets_semantic.view(-1).data.clone()
                mask_scan[:] = 1
                # Ignore Unknown Voxels from before.
                mask_scan[targets_semantic.view(-1).eq(opt.num_classes -
                                                       1)] = 0
                mask_scan_indices = mask_scan.nonzero().squeeze()
                if len(mask_scan_indices.shape) == 0:
                    continue

                # ToDo: What if you generate targets_scan from volumetric grid?
                # ToDo: You should get the same result but confirm.
                targets_scan = targets_semantic.view(-1).data.clone()
                targets_scan[torch.ge(targets_scan, 1) *
                             torch.lt(targets_scan, num_classes - 1)] = 1
                targets_scan[torch.eq(targets_scan, num_classes -
                                      1)] = 2  # Label 41 with class 2
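                # Resulting scan targets: label 0 -> 0 (known-free),
                # labels 1..num_classes-2 -> 1 (known-occupied),
                # label num_classes-1 (41, the unknown label) -> 2 (unknown).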

            transforms = world_to_grids[v].unsqueeze(1)
            transforms = transforms.expand(batch_size, num_images, 4,
                                           4).contiguous().view(-1, 4, 4)
            if CUDA_AVAILABLE:
                transforms = transforms.cuda()

            # Load the data
            # print("loading the data")
            is_load_success = data_util.load_frames_multi(
                opt.data_path_2d, frames[v], depth_images, color_images,
                camera_poses, color_mean, color_std)
            if not is_load_success:
                continue

            # 3d Input
            volume = volumes[v]
            # Get indices of voxels to be removed if training scan completion
            random_center_voxel_indices = torch.Tensor()  # Empty Tensor
            if opt.train_scan_completion:
                # ToDo: For all samples in each batch, the same random voxels are removed.
                # ToDo: Voxels that are already unknown also get removed.
                random_center_voxel_indices = projection.get_random_center_voxels_index(
                    opt.voxel_removal_fraction)
                # Mark the selected center voxels of the 3D volume as Unknown (0)
                volume[:, :, random_center_voxel_indices,
                       projection.volume_dims[0] // 2,
                       projection.volume_dims[1] // 2] = 0

            # Compute projection mapping and mark center voxels as Unknown if training for scan completion
            proj_mapping = [
                projection.compute_projection(d, c, t,
                                              random_center_voxel_indices)
                for d, c, t in zip(depth_images, camera_poses, transforms)
            ]
            if None in proj_mapping:  # Invalid sample
                print('No mapping in proj_mapping')
                continue
            proj_mapping = list(zip(*proj_mapping))
            proj_ind_3d = torch.stack(proj_mapping[0])
            proj_ind_2d = torch.stack(proj_mapping[1])

            if opt.use_proxy_loss:
                data_util.load_label_frames(opt.data_path_2d, frames[v],
                                            label_images, num_classes)
                mask2d = label_images.view(-1).clone()
                for k in range(num_classes):
                    if criterion_weights_semantic[k] == 0:
                        mask2d[mask2d.eq(k)] = 0
                mask2d = mask2d.nonzero().squeeze()
                if len(mask2d.shape) == 0:
                    continue  # nothing to optimize for here

            # 2d
            imageft_fixed = model2d_fixed(
                torch.autograd.Variable(color_images))
            imageft = model2d_trainable(imageft_fixed)
            if opt.use_proxy_loss:
                ft2d = model2d_classifier(imageft)
                ft2d = ft2d.permute(0, 2, 3, 1).contiguous()

            # 2d/3d
            input3d = torch.autograd.Variable(volume)
            if CUDA_AVAILABLE:
                input3d = input3d.cuda()

            # Forward Pass
            output_semantic, output_scan = model(
                input3d, imageft, torch.autograd.Variable(proj_ind_3d),
                torch.autograd.Variable(proj_ind_2d), grid_dims)

            # Print GPU memory usage once to be sure the GPU is actually used; Colab can be a bit unpredictable.
            check_gpu_memory_usage_once()

            # Compute Scan and semantic Loss
            loss_semantic = criterion_semantic(
                output_semantic.view(-1, num_classes),
                targets_semantic.view(-1))
            train_loss_semantic.append(loss_semantic.item())
            if opt.train_scan_completion:
                loss_scan = criterion_scan(
                    output_scan.view(-1, _NUM_OCCUPANCY_STATES),
                    targets_scan.view(-1))
                train_loss_scan.append(loss_scan.item())
                loss = loss_scan + loss_semantic
            else:
                loss = loss_semantic

            # Backpropagate total loss.
            # ToDo: Note using same optimizer for both branches. Is there a need for different optimizers?
            optimizer.zero_grad()
            optimizer2d.zero_grad()
            if opt.use_proxy_loss:
                loss.backward(retain_graph=True)
            else:
                loss.backward()
            optimizer.step()
            # optimizer2d.step() is probably required even when use_proxy_loss is False, since the backprojection
            # layer is differentiable, so gradients from the 3D model also flow back into the 2D model.
            optimizer2d.step()

            # ToDo: Check if proxy loss is required. If optimizer2d is injecting gradients, proxy loss may be needed.
            if opt.use_proxy_loss:
                loss2d = criterion2d(
                    ft2d.view(-1, num_classes),
                    torch.autograd.Variable(label_images.view(-1)))
                train_loss_2d.append(loss2d.item())
                optimizer2d.zero_grad()
                optimizer2dc.zero_grad()
                loss2d.backward()
                optimizer2dc.step()
                optimizer2d.step()
                # confusion
                y = ft2d.data
                y = y.view(-1, num_classes)[:, :-1]
                _, predictions = y.max(1)
                predictions = predictions.view(-1)
                k = label_images.view(-1)
                confusion2d.add(torch.index_select(predictions, 0, mask2d),
                                torch.index_select(k, 0, mask2d))

            # Confusion for Semantic
            y = output_semantic.data
            # Discard the semantic prediction for class num_classes-1 (Unknown voxel)
            y = y.view(y.nelement() // y.size(2), num_classes)[:, :-1]
            _, predictions = y.max(1)
            predictions = predictions.view(-1)
            k = targets_semantic.data.view(-1)
            confusion.add(
                torch.index_select(predictions, 0, mask_semantic_indices),
                torch.index_select(k, 0, mask_semantic_indices))

            # Confusion for Scan completion
            if opt.train_scan_completion:
                y = output_scan.data
                # Discard the prediction for the Unknown occupancy state in output_scan
                y = y.view(y.nelement() // y.size(2),
                           _NUM_OCCUPANCY_STATES)[:, :-1]
                _, predictions_scan = y.max(1)
                predictions_scan = predictions_scan.view(-1)
                k = targets_scan.data.view(-1)
                confusion_scan.add(
                    torch.index_select(predictions_scan, 0, mask_scan_indices),
                    torch.index_select(k, 0, mask_scan_indices))

            # Log loss for the current iteration; print only every (64 // batch_size) iterations.
            msg1 = _SPLITTER.join(
                [str(f)
                 for f in [epoch, iter, loss_semantic.item()]])
            log_file_semantic.write(msg1 + '\n')
            if opt.train_scan_completion:
                msg2 = _SPLITTER.join(
                    [str(f)
                     for f in [epoch, iter, loss_scan.item()]])
                log_file_scan.write(msg2 + '\n')

            # Infrequent logging keeps Chrome from crashing on Colab and puts less strain on Jupyter.
            if iter % (64 // batch_size) == 0:
                print("Semantic: %s, %0.6f" % (msg1, time.time() - iter_start))
                if opt.train_scan_completion:
                    print("Scan    : %s" % msg2)

            iter += 1
            if iter % (10000 // batch_size) == 0:  # Save more frequently, since this runs on Google Colaboratory.
                # Save 3d model
                if not opt.train_scan_completion:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            opt.output,
                            'model-semantic-epoch%s-iter%s-Sem%s.pth' %
                            (epoch, iter, str(loss_semantic.item()))))
                else:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            opt.output,
                            'model-semantic_and_scan-epoch%s-iter%s-sem%s-scan%s.pth'
                            % (epoch, iter, str(
                                loss_semantic.item()), str(loss_scan.item()))))
                # Save 2d model
                # Important ToDo: Do we need to retrain model2d_trainable?
                torch.save(
                    model2d_trainable.state_dict(),
                    os.path.join(opt.output,
                                 'model2d-iter%s-epoch%s.pth' % (iter, epoch)))
                if opt.use_proxy_loss:
                    torch.save(
                        model2d_classifier.state_dict(),
                        os.path.join(
                            opt.output,
                            'model2dc-iter%s-epoch%s.pth' % (iter, epoch)))
            if iter == 1:
                torch.save(model2d_fixed.state_dict(),
                           os.path.join(opt.output, 'model2dfixed.pth'))

            if iter % 100 == 0:
                evaluate_confusion(confusion, train_loss_semantic, epoch, iter,
                                   -1, 'TrainSemantic', log_file_semantic,
                                   num_classes)
                if opt.train_scan_completion:
                    evaluate_confusion(confusion_scan, train_loss_scan, epoch,
                                       iter, -1, 'TrainScan', log_file_scan,
                                       _NUM_OCCUPANCY_STATES)
                if opt.use_proxy_loss:
                    evaluate_confusion(confusion2d, train_loss_2d, epoch, iter,
                                       -1, 'Train2d', log_file_2d, num_classes)

    end = time.time()
    took = end - start
    evaluate_confusion(confusion, train_loss_semantic, epoch, iter, took,
                       'TrainSemantic', log_file_semantic, num_classes)
    if opt.train_scan_completion:
        evaluate_confusion(confusion_scan, train_loss_scan, epoch, iter, took,
                           'TrainScan', log_file_scan, _NUM_OCCUPANCY_STATES)
    if opt.use_proxy_loss:
        evaluate_confusion(confusion2d, train_loss_2d, epoch, iter, took,
                           'Train2d', log_file_2d, num_classes)
    return train_loss_semantic, train_loss_scan, iter, train_loss_2d
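
The listings above leave the outer driver loop to the caller. A rough, hypothetical sketch of how train() above and the validation routine from Example #1 might be chained per epoch is shown below; the file lists (train_files, val_files), opt.max_epoch and the validation log-file handles are illustrative assumptions, not part of the original code.

iteration = 0
for epoch in range(opt.max_epoch):
    for train_file in train_files:
        train_loss_sem, train_loss_scan, iteration, train_loss_2d = train(
            epoch, iteration, log_file_semantic, log_file_scan, train_file, log_file_2d)
    val_loss_sem, val_loss_scan, val_loss_2d = test(
        epoch, iteration, log_file_semantic_val, log_file_scan_val, val_files[0], log_file_2d_val)
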
Example #4
def train(epoch, iter, log_file, train_file, log_file_2d):
    train_loss = []
    train_loss_2d = []
    model.train()
    start = time.time()
    model2d_trainable.train()
    if opt.use_proxy_loss:
        model2d_classifier.train()

    volumes, labels, frames, world_to_grids = data_util.load_hdf5_data(train_file, num_classes)
    frames = frames[:, :2+num_images]
    volumes = volumes.permute(0, 1, 4, 3, 2)
    labels = labels.permute(0, 1, 4, 3, 2)

    labels = labels[:, 0, :, grid_centerX, grid_centerY]  # center columns as targets
    num_samples = volumes.shape[0]
    # shuffle
    indices = torch.randperm(num_samples).long().split(batch_size)
    # remove last mini-batch so that all the batches have equal size
    indices = indices[:-1]

    mask = torch.cuda.LongTensor(batch_size*column_height)
    depth_images = torch.cuda.FloatTensor(batch_size * num_images, proj_image_dims[1], proj_image_dims[0])
    color_images = torch.cuda.FloatTensor(batch_size * num_images, 3, input_image_dims[1], input_image_dims[0])
    camera_poses = torch.cuda.FloatTensor(batch_size * num_images, 4, 4)
    label_images = torch.cuda.LongTensor(batch_size * num_images, proj_image_dims[1], proj_image_dims[0])

    for t, v in enumerate(indices):
        targets = torch.autograd.Variable(labels[v].cuda())
        # valid targets
        mask = targets.view(-1).data.clone()
        for k in range(num_classes):
            if criterion_weights[k] == 0:
                mask[mask.eq(k)] = 0
        maskindices = mask.nonzero().squeeze()
        if len(maskindices.shape) == 0:
            continue
        transforms = world_to_grids[v].unsqueeze(1)
        transforms = transforms.expand(batch_size, num_images, 4, 4).contiguous().view(-1, 4, 4).cuda()
        data_util.load_frames_multi(opt.data_path_2d, frames[v], depth_images, color_images, camera_poses, color_mean, color_std)

        # compute projection mapping
        proj_mapping = [projection.compute_projection(d, c, t) for d, c, t in zip(depth_images, camera_poses, transforms)]

        if None in proj_mapping:  # invalid sample
            # print('(invalid sample)')
            continue
        proj_mapping = list(zip(*proj_mapping))
        proj_ind_3d = torch.stack(proj_mapping[0])
        proj_ind_2d = torch.stack(proj_mapping[1])

        if opt.use_proxy_loss:
            data_util.load_label_frames(opt.data_path_2d, frames[v], label_images, num_classes)
            mask2d = label_images.view(-1).clone()
            for k in range(num_classes):
                if criterion_weights[k] == 0:
                    mask2d[mask2d.eq(k)] = 0
            mask2d = mask2d.nonzero().squeeze()
            if len(mask2d.shape) == 0:
                continue  # nothing to optimize for here
        # 2d
        imageft_fixed = model2d_fixed(torch.autograd.Variable(color_images))
        imageft = model2d_trainable(imageft_fixed)
        if opt.use_proxy_loss:
            ft2d = model2d_classifier(imageft)
            ft2d = ft2d.permute(0, 2, 3, 1).contiguous()

        # 2d/3d
        input3d = torch.autograd.Variable(volumes[v].cuda())
        output = model(input3d, imageft, torch.autograd.Variable(proj_ind_3d), torch.autograd.Variable(proj_ind_2d), grid_dims)

        loss = criterion(output.view(-1, num_classes), targets.view(-1))
        train_loss.append(loss.item())
        optimizer.zero_grad()
        optimizer2d.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer2d.step()
        if opt.use_proxy_loss:
            loss2d = criterion2d(ft2d.view(-1, num_classes), torch.autograd.Variable(label_images.view(-1)))
            train_loss_2d.append(loss2d.item())
            optimizer2d.zero_grad()
            optimizer2dc.zero_grad()
            loss2d.backward()
            optimizer2dc.step()
            optimizer2d.step()
            # confusion
            y = ft2d.data
            y = y.view(-1, num_classes)[:, :-1]
            _, predictions = y.max(1)
            predictions = predictions.view(-1)
            k = label_images.view(-1)
            confusion2d.add(torch.index_select(predictions, 0, mask2d), torch.index_select(k, 0, mask2d))

        # confusion
        y = output.data
        y = y.view(y.nelement() // y.size(2), num_classes)[:, :-1]
        _, predictions = y.max(1)
        predictions = predictions.view(-1)
        k = targets.data.view(-1)
        confusion.add(torch.index_select(predictions, 0, maskindices), torch.index_select(k, 0, maskindices))
        log_file.write(_SPLITTER.join([str(f) for f in [epoch, iter, loss.item()]]) + '\n')
        iter += 1
        if iter % 10000 == 0:
            torch.save(model.state_dict(), os.path.join(opt.output, 'model-iter%s-epoch%s.pth' % (iter, epoch)))
            torch.save(model2d_trainable.state_dict(), os.path.join(opt.output, 'model2d-iter%s-epoch%s.pth' % (iter, epoch)))
            if opt.use_proxy_loss:
                torch.save(model2d_classifier.state_dict(), os.path.join(opt.output, 'model2dc-iter%s-epoch%s.pth' % (iter, epoch)))
        if iter == 1:
            torch.save(model2d_fixed.state_dict(), os.path.join(opt.output, 'model2dfixed.pth'))

        if iter % 100 == 0:
            evaluate_confusion(confusion, train_loss, epoch, iter, -1, 'Train', log_file)
            if opt.use_proxy_loss:
                evaluate_confusion(confusion2d, train_loss_2d, epoch, iter, -1, 'Train2d', log_file_2d)

    end = time.time()
    took = end - start
    evaluate_confusion(confusion, train_loss, epoch, iter, took, 'Train', log_file)
    if opt.use_proxy_loss:
        evaluate_confusion(confusion2d, train_loss_2d, epoch, iter, took, 'Train2d', log_file_2d)
    return train_loss, iter, train_loss_2d