Code Example #1
File: train.py  Project: sushantmakadia/Pix2Vox-1
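
# Module-level imports assumed by this excerpt (they live in train.py, not in
# the snippet itself): os, random, torch, utils.data_loaders,
# utils.data_transforms, utils.network_utils, the Encoder/Decoder/Refiner/Merger
# model classes, the test_net() function, `from datetime import datetime as dt`,
# `from time import time`, and a TensorBoard SummaryWriter implementation.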
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solver
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                 encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' %
                        (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()
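    # BCELoss expects probabilities in [0, 1]; the decoder, refiner and merger
    # outputs are therefore assumed to be sigmoid-activated occupancy volumes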

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']

        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' \
                 % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
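        # NOTE: stepping the schedulers before the optimizers follows the
        # pre-1.1 PyTorch API; PyTorch >= 1.1 expects scheduler.step() after
        # optimizer.step() and emits a UserWarning for this ordering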
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
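            # The BCE loss on the (merged or view-averaged) volume is scaled by 10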
            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
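                # Both losses backpropagate through the shared encoder/decoder
                # graph, so the first backward pass must retain it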
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss',
                                    encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(),
                                    n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' % \
                (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, \
                    batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                                epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
            (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, \
                encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' % \
                (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'best-ckpt.pth'), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
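
A minimal driver for the function above might look like the following sketch. The import path `config` and the attribute layout of `cfg` are assumptions inferred from how train_net() reads its configuration, not verified against the sushantmakadia/Pix2Vox-1 repository.

# Hypothetical entry point for train_net(); the `config` module and its `cfg`
# object (an attribute-style namespace with CONST/TRAIN/TEST/DATASET/NETWORK/DIR
# sections) are assumptions based on the snippet above.
from config import cfg

if __name__ == '__main__':
    train_net(cfg)
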
Code Example #2
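
# Module-level imports assumed by this excerpt: json, torch, torch.nn as nn,
# torch.optim as optim, pytorch_lightning as pl, utils.network_utils,
# `from datetime import datetime as dt`, `from omegaconf import DictConfig`,
# and the Encoder/Decoder/Refiner/Merger model classes.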
class Model(pl.LightningModule):

    def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
        super().__init__()
        self.cfg_network = cfg_network
        self.cfg_tester = cfg_tester

        # Manual optimization: training_step() calls manual_backward() and steps
        # the optimizers itself, so automatic optimization must be disabled
        # (settable as an attribute in PyTorch Lightning >= 1.3)
        self.automatic_optimization = False

        # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
        torch.backends.cudnn.benchmark = True

        # Set up networks
        self.encoder = Encoder(cfg_network)
        self.decoder = Decoder(cfg_network)
        self.refiner = Refiner(cfg_network)
        self.merger = Merger(cfg_network)
        
        # Initialize weights of networks
        self.encoder.apply(utils.network_utils.init_weights)
        self.decoder.apply(utils.network_utils.init_weights)
        self.refiner.apply(utils.network_utils.init_weights)
        self.merger.apply(utils.network_utils.init_weights)
        
        self.bce_loss = nn.BCELoss()

    def configure_optimizers(self):
        params = self.cfg_network.optimization
        # Set up solver
        if params.policy == 'adam':
            encoder_solver = optim.Adam(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                            lr=params.encoder_lr,
                                            betas=params.betas)
            decoder_solver = optim.Adam(self.decoder.parameters(),
                                            lr=params.decoder_lr,
                                            betas=params.betas)
            refiner_solver = optim.Adam(self.refiner.parameters(),
                                            lr=params.refiner_lr,
                                            betas=params.betas)
            merger_solver = optim.Adam(self.merger.parameters(),
                                             lr=params.merger_lr,
                                             betas=params.betas)
        elif params.policy == 'sgd':
            encoder_solver = optim.SGD(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                            lr=params.encoder_lr,
                                            momentum=params.momentum)
            decoder_solver = optim.SGD(self.decoder.parameters(),
                                            lr=params.decoder_lr,
                                            momentum=params.momentum)
            refiner_solver = optim.SGD(self.refiner.parameters(),
                                            lr=params.refiner_lr,
                                            momentum=params.momentum)
            merger_solver = optim.SGD(self.merger.parameters(),
                                            lr=params.merger_lr,
                                            momentum=params.momentum)
        else:
            raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), params.policy))

        # Set up learning rate scheduler to decay learning rates dynamically
        encoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                              milestones=params.encoder_lr_milestones,
                                                              gamma=params.gamma)
        decoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                              milestones=params.decoder_lr_milestones,
                                                              gamma=params.gamma)
        refiner_lr_scheduler = optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                              milestones=params.refiner_lr_milestones,
                                                              gamma=params.gamma)
        merger_lr_scheduler = optim.lr_scheduler.MultiStepLR(merger_solver,
                                                             milestones=params.merger_lr_milestones,
                                                             gamma=params.gamma)
        
        return [encoder_solver, decoder_solver, refiner_solver, merger_solver], \
               [encoder_lr_scheduler, decoder_lr_scheduler, refiner_lr_scheduler, merger_lr_scheduler]
    
    def _fwd(self, batch):
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch

        image_features = self.encoder(rendering_images)
        raw_features, generated_volumes = self.decoder(image_features)

        if self.cfg_network.use_merger and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_merger:
            generated_volumes = self.merger(raw_features, generated_volumes)
        else:
            generated_volumes = torch.mean(generated_volumes, dim=1)
        encoder_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        
        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            generated_volumes = self.refiner(generated_volumes)
            refiner_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        else:
            refiner_loss = encoder_loss
        
        return generated_volumes, encoder_loss, refiner_loss

    def training_step(self, batch, batch_idx):
        # Manual optimization: retrieve the four optimizers from configure_optimizers()
        opt_enc, opt_dec, opt_ref, opt_merg = self.optimizers()
        
        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)
        
        self.log('loss/EncoderDecoder', encoder_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('loss/Refiner', refiner_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            # Both losses share the encoder/decoder graph, so retain it for the second pass
            self.manual_backward(encoder_loss, retain_graph=True)
            self.manual_backward(refiner_loss)
        else:
            self.manual_backward(encoder_loss)
            
        for opt in (opt_enc, opt_dec, opt_ref, opt_merg):
            opt.step()
            opt.zero_grad()

    def training_epoch_end(self, outputs) -> None:
        # With manual optimization, Lightning does not step the LR schedulers
        # automatically; step all four once per epoch (self.lr_schedulers() is
        # available with manual optimization in Lightning >= 1.3)
        for lr_scheduler in self.lr_schedulers():
            lr_scheduler.step()

        # Update Rendering Views
        if self.cfg_network.update_n_views_rendering:
            n_views_rendering = self.trainer.datamodule.update_n_views_rendering()
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), self.current_epoch + 2, self.trainer.max_epochs, n_views_rendering))

    def _eval_step(self, batch, batch_idx):
        # SUPPORTS ONLY BATCH_SIZE=1
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch
        taxonomy_id = taxonomy_names[0]
        sample_name = sample_names[0]

        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)

        self.log('val_loss/EncoderDecoder', encoder_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        self.log('val_loss/Refiner', refiner_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        # IoU per sample
        sample_iou = []
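        # Binarize the prediction at each threshold, then compute
        # intersection-over-union against the ground-truth occupancy grid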
        for th in self.cfg_tester.voxel_thresh:
            _volume = torch.ge(generated_volumes, th).float()
            intersection = torch.sum(_volume.mul(ground_truth_volumes)).float()
            union = torch.sum(
                torch.ge(_volume.add(ground_truth_volumes), 1)).float()
            sample_iou.append((intersection / union).item())

        # Print sample loss and IoU (the total sample count is not known inside a step)
        n_samples = -1
        print('\n[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
              (dt.now(), batch_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
               refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

        return {
            'taxonomy_id': taxonomy_id,
            'sample_name': sample_name,
            'sample_iou': sample_iou
        }
        
    def _eval_epoch_end(self, outputs):
        # Load taxonomies of dataset
        taxonomies = []
        taxonomy_path = self.trainer.datamodule.get_test_taxonomy_file_path()
        with open(taxonomy_path, encoding='utf-8') as file:
            taxonomies = json.loads(file.read())
        taxonomies = {t['taxonomy_id']: t for t in taxonomies}

        test_iou = {}
        for output in outputs:
            taxonomy_id, sample_name, sample_iou = output[
                'taxonomy_id'], output['sample_name'], output['sample_iou']
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

        mean_iou = []
        for taxonomy_id in test_iou:
            test_iou[taxonomy_id]['iou'] = torch.mean(
                torch.tensor(test_iou[taxonomy_id]['iou']), dim=0)
            mean_iou.append(test_iou[taxonomy_id]['iou']
                            * test_iou[taxonomy_id]['n_samples'])
        n_samples = len(outputs)
        mean_iou = torch.stack(mean_iou)
        mean_iou = torch.sum(mean_iou, dim=0) / n_samples

        # Print header
        print('============================ TEST RESULTS ============================')
        print('Taxonomy', end='\t')
        print('#Sample', end='\t')
        print(' Baseline', end='\t')
        for th in self.cfg_tester.voxel_thresh:
            print('t=%.2f' % th, end='\t')
        print()
        # Print body
        for taxonomy_id in test_iou:
            print('%s' % taxonomies[taxonomy_id]
                  ['taxonomy_name'].ljust(8), end='\t')
            print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
            if 'baseline' in taxonomies[taxonomy_id]:
                n_views_rendering = self.trainer.datamodule.get_n_views_rendering()
                print('%.4f' % taxonomies[taxonomy_id]['baseline']
                      ['%d-view' % n_views_rendering], end='\t\t')
            else:
                print('N/a', end='\t\t')

            for ti in test_iou[taxonomy_id]['iou']:
                print('%.4f' % ti, end='\t')
            print()
        # Print mean IoU for each threshold
        print('Overall ', end='\t\t\t\t')
        for mi in mean_iou:
            print('%.4f' % mi, end='\t')
        print('\n')

        max_iou = torch.max(mean_iou)
        self.log('Refiner/IoU', max_iou, prog_bar=True, on_epoch=True)
    
    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def validation_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)
        
    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def test_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)

    def get_progress_bar_dict(self):
        # don't show the loss as it's None
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items
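
As a rough sketch, wiring this LightningModule into a trainer could look like the following. The `Pix2VoxDataModule` name is a placeholder invented for illustration; only `Model(cfg_network, cfg_tester)` comes from the snippet above.

# Hypothetical wiring; Pix2VoxDataModule stands in for whatever LightningDataModule
# yields (taxonomy_names, sample_names, rendering_images, ground_truth_volumes)
# batches, as _fwd() expects.
import pytorch_lightning as pl

def run(cfg_network, cfg_tester):
    model = Model(cfg_network, cfg_tester)       # manual optimization is enabled in __init__
    datamodule = Pix2VoxDataModule(cfg_network)  # placeholder datamodule
    trainer = pl.Trainer(max_epochs=250, gpus=1)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)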