Example #1
import argparse
import os
import time

import torch
from torch.utils.tensorboard import SummaryWriter

# Project-local helpers; the module paths below are assumptions, adjust them
# to this repository's actual layout.
import criteria
from data import CustomDataLoader, createDataLoaders
from model import FullModel
from utils import AverageMeter, DepthNorm, LogProgress, Result, logEpoch


def main():
    modality_names = CustomDataLoader.modality_names
    # Arguments
    parser = argparse.ArgumentParser(
        description=
        'High Quality Monocular Depth Estimation via Transfer Learning')
    parser.add_argument('--epochs',
                        default=30,
                        type=int,
                        help='number of total epochs to run')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.0001,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')
    parser.add_argument('--path',
                        default='../data/nyudepthv2',
                        help='path to the dataset root')
    parser.add_argument('--data', default='nyudepthv2', help='dataset name')
    parser.add_argument('--modality',
                        '-m',
                        metavar='MODALITY',
                        default='rgb',
                        choices=modality_names,
                        help='modality: ' + ' | '.join(modality_names) +
                        ' (default: rgb)')
    parser.add_argument('-j',
                        '--workers',
                        default=16,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 16)')
    args = parser.parse_args()
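    # Example invocation (the script name train.py is an assumption):
    #   python train.py --epochs 30 --lr 0.0001 --batch_size 4 --modality rgb -j 16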

    # Create model
    torch.cuda.empty_cache()
    model = FullModel().cuda()
    print('Model created.')

    # Training parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    batch_size = args.batch_size
    prefix = 'densenet_' + str(batch_size)

    # Load data
    train_loader, test_loader = createDataLoaders(args)

    torch.cuda.empty_cache()
    # Logging
    writer = SummaryWriter(comment='{}-lr{}-e{}-batch_size{}'.format(
        prefix, args.lr, args.epochs, args.batch_size),
                           flush_secs=30)
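    # With no explicit log_dir, SummaryWriter writes events under ./runs/,
    # suffixed with the comment string; view them with: tensorboard --logdir runs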

    # Loss: masked L1, constructed once rather than inside the training loop
    criterion = criteria.MaskedL1Loss().cuda()
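    # (MaskedL1Loss is project-local; in similar depth codebases such as
    # sparse-to-dense it averages |pred - target| only over pixels that have
    # valid, i.e. non-zero, ground-truth depth.)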

    # Start training...
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        losses = AverageMeter()
        N = len(train_loader)

        # Switch to train mode
        model.train()

        end = time.time()

        for i, (image, depth) in enumerate(train_loader):
            optimizer.zero_grad()
            # Prepare sample and target
            image, depth = image.cuda(), depth.cuda()
            torch.cuda.synchronize()
            data_time = time.time() - end
            # Normalize depth
            depth_n = DepthNorm(depth)
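            # (DepthNorm is project-local; in DenseDepth it is maxDepth / depth,
            # so the network regresses an inverse-scaled target. Check the
            # repository's utils for the exact definition.)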
            # Predict
            output = model(image)
            # Compute the loss
            loss = criterion(output, depth_n)
            # l_ssim = torch.clamp((1 - ssim(output, depth_n, val_range=1000.0 / 10.0)) * 0.5, 0, 1)

            # loss = (1.0 * l_ssim) + (0.1 * loss)

            # Update step
            losses.update(loss.item(), image.size(0))
            loss.backward()
            optimizer.step()
            torch.cuda.synchronize()
            gpu_time = time.time() - end
            result = Result()
            result.evaluate(output.data, depth_n.data)
            # Measure elapsed time and estimate time left in this epoch
            batch_time.update(result, gpu_time, data_time, image.size(0))
            eta = '{:.1f}s'.format((time.time() - end) * (N - i - 1))
            end = time.time()
            # Log progress
            niter = epoch * N + i

            if i % 5 == 0:
                # Print to console
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Data/GPU time {batch_time.sum_data_time:.3f} ({batch_time.sum_gpu_time:.3f})\t'
                    'ETA {eta}\t'
                    'Loss {loss:.4f} RMSE {rmse:.3f}'.format(
                        epoch,
                        i,
                        N,
                        batch_time=batch_time,
                        loss=loss.item(),
                        rmse=result.rmse,
                        eta=eta))

                # Log to tensorboard
                writer.add_scalar('Train/Loss', loss, niter)

            if i % 300 == 0:
                LogProgress(model, writer, test_loader, niter)

        # Record the epoch's average loss
        logEpoch(epoch, losses.avg, "TrainOutput.txt")
        LogProgress(model, writer, test_loader, niter)
        writer.add_scalar('Train/Loss.avg', losses.avg, epoch)
    # Save the final model (both the pickled module and its state_dict);
    # makedirs with exist_ok=True creates the whole tree even when part of
    # it already exists
    base_dir = "TrainedModel"
    entire_model_dir = os.path.join(base_dir, "EntireModel")
    model_param_dir = os.path.join(base_dir, "ModelParameters")
    os.makedirs(entire_model_dir, exist_ok=True)
    os.makedirs(model_param_dir, exist_ok=True)
    torch_model_name = "model_batch_{}_epochs_{}.pt".format(
        args.batch_size, args.epochs)
    torch.save(model, os.path.join(entire_model_dir, torch_model_name))
    torch.save(model.state_dict(),
               os.path.join(model_param_dir, torch_model_name))
    print('done')


if __name__ == '__main__':
    main()
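The two torch.save calls above persist both the pickled model object and just
its weights. A minimal reload sketch, assuming the same directory layout and
an importable FullModel (the module path here is an assumption):

import os

import torch

from model import FullModel

name = "model_batch_1_epochs_30.pt"

# Option 1: load the entire pickled module (requires the identical class
# definitions to be importable under the same names).
model = torch.load(os.path.join("TrainedModel", "EntireModel", name))

# Option 2 (generally preferred): rebuild the architecture, then restore
# only the parameters from the saved state_dict.
model = FullModel().cuda()
model.load_state_dict(
    torch.load(os.path.join("TrainedModel", "ModelParameters", name)))
model.eval()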