Example #1
def main():
    global args, best_result, output_directory

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
        train_loader = NYUDepth_loader(args.data_path,
                                       batch_size=args.batch_size,
                                       isTrain=True)
        val_loader = NYUDepth_loader(args.data_path,
                                     batch_size=args.batch_size,
                                     isTrain=False)
    else:
        train_loader = NYUDepth_loader(args.data_path,
                                       batch_size=args.batch_size,
                                       isTrain=True)
        val_loader = NYUDepth_loader(args.data_path, isTrain=False)

    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)

        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']

        # strip the DataParallel wrapper (.module), since the checkpoint was saved from a multi-GPU run
        model_dict = checkpoint['model'].module.state_dict()
        model = DORN_nyu.DORN()
        model.load_state_dict(model_dict)

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        del checkpoint  # clear memory
        del model_dict
    else:
        print("=> creating Model")
        model = DORN_nyu.DORN()
        print("=> model created.")
        start_epoch = 0

    # in the paper, the ASPP module's learning rate is 20x that of the other modules
    train_params = [{
        'params': model.feature_extractor.parameters(),
        'lr': args.lr
    }, {
        'params': model.aspp_module.parameters(),
        'lr': args.lr * 20
    }, {
        'params': model.orl.parameters(),
        'lr': args.lr
    }]

    optimizer = torch.optim.SGD(train_params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # DataParallel() works with a single GPU as well as with multiple GPUs
    model = nn.DataParallel(model)
    model = model.cuda()

    # use ReduceLROnPlateau to lower the learning rate when the validation metric plateaus
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=args.lr_patience)

    # loss function
    criterion = criteria.ordLoss()

    # create directory path
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str)

    # create log
    log_path = os.path.join(
        output_directory, 'logs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    for epoch in range(start_epoch, args.epochs):
        train(train_loader, model, criterion, optimizer, epoch,
              logger)  # train for one epoch
        result, img_merge = validate(val_loader, model, epoch,
                                     logger)  # evaluate on validation set

        for i, param_group in enumerate(optimizer.param_groups):
            old_lr = float(param_group['lr'])

            logger.add_scalar('Lr/lr_' + str(i), old_lr, epoch)

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, dd31={:.3f}, "
                    "t_gpu={:.4f}".format(epoch, result.rmse, result.absrel,
                                          result.lg10, result.delta1,
                                          result.delta2, result.delta3,
                                          result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        # save checkpoint for each epoch
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        # reduce the learning rate when rml (absolute relative error) stops falling
        scheduler.step(result.absrel)

    logger.close()
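Note: the per-module learning rates above (ASPP at 20x the base rate) and ReduceLROnPlateau interact in a simple way: the scheduler scales every param group's learning rate by the same factor once the monitored metric stops improving. A minimal, self-contained sketch of that behavior; the two Linear layers and the constant metric are toy stand-ins for DORN's sub-modules and the validation absrel, not the repository's code:

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

# toy stand-ins for DORN's sub-modules (illustrative names only)
feature_extractor = nn.Linear(8, 8)
aspp_module = nn.Linear(8, 8)

base_lr = 1e-4
optimizer = torch.optim.SGD(
    [{'params': feature_extractor.parameters(), 'lr': base_lr},
     {'params': aspp_module.parameters(), 'lr': base_lr * 20}],
    lr=base_lr, momentum=0.9, weight_decay=5e-4)

# 'min' mode: multiply every group's lr by the default factor (0.1)
# once the metric has not improved for `patience` consecutive epochs
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

for epoch in range(6):
    scheduler.step(0.5)  # constant metric => plateau => both lrs drop after the patience window
    print(epoch, [group['lr'] for group in optimizer.param_groups])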
Example #2
def main():
    global args, best_result, output_directory

    if torch.cuda.device_count() > 1:
        args.batch_size = args.batch_size * torch.cuda.device_count()
        train_loader = NYUDepth_loader(args.data_path,
                                       batch_size=args.batch_size,
                                       isTrain=True)
        val_loader = NYUDepth_loader(args.data_path,
                                     batch_size=args.batch_size,
                                     isTrain=False)
    else:
        train_loader = NYUDepth_loader(args.data_path,
                                       batch_size=args.batch_size,
                                       isTrain=True)
        val_loader = NYUDepth_loader(args.data_path, isTrain=False)

    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        # args = checkpoint['args']
        # print('retained args:', args)
        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        if torch.cuda.device_count() > 1:
            # checkpoints saved from multi-GPU training keep the DataParallel wrapper, so go through .module
            model_dict = checkpoint['model'].module.state_dict()
        else:
            model_dict = checkpoint['model'].state_dict()
        model = DORN_nyu.DORN()
        model.load_state_dict(model_dict)
        del model_dict  # free the loaded state dict
        # optimize with SGD
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> creating Model")
        model = DORN_nyu.DORN()
        print("=> model created.")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
        start_epoch = 0
    # if GPUs are available, train with DataParallel (it uses every visible GPU)
    if torch.cuda.device_count():
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.cuda()

    # define the loss function
    criterion = criteria.ordLoss()

    # create the output directory for results
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')

    log_path = os.path.join(
        output_directory, 'logs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    for epoch in range(start_epoch, args.epochs):
        # lr = utils.adjust_learning_rate(optimizer, args.lr, epoch)  # update the learning rate

        train(train_loader, model, criterion, optimizer, epoch,
              logger)  # train for one epoch
        result, img_merge = validate(val_loader, model, epoch,
                                     logger)  # evaluate on validation set

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}\nrmse={:.3f}\nrml={:.3f}\nlog10={:.3f}\nd1={:.3f}\nd2={:.3f}\ndd31={:.3f}\nt_gpu={:.4f}\n"
                    .format(epoch, result.rmse, result.absrel, result.lg10,
                            result.delta1, result.delta2, result.delta3,
                            result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        # save a checkpoint every epoch
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

    logger.close()
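Note: the .module handling in the resume branch above exists because a state dict saved from an nn.DataParallel-wrapped model carries a 'module.' prefix on every key. A minimal sketch of the two usual ways to deal with that prefix; the Linear layer and the hand-built dict are stand-ins for DORN and its checkpoint, not the repository's code:

import torch.nn as nn

net = nn.Linear(4, 2)

# what a checkpoint saved from nn.DataParallel(net) looks like: every key is prefixed with 'module.'
saved = {'module.' + k: v for k, v in net.state_dict().items()}

fresh = nn.Linear(4, 2)

# option 1 (used in the examples above): unwrap with .module before calling state_dict()
# option 2: strip the prefix from an already-saved state dict, as below
stripped = {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in saved.items()}
fresh.load_state_dict(stripped)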
Example #3
def main():
    global args, best_result, output_directory
    args.dataset = 'hacker'
    args.batch_size = 1

    # set random seed
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.resume:  # default false
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)

        start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        optimizer = checkpoint['optimizer']

        # reuse the saved model object directly to work around 'out of memory' on resume
        model = checkpoint['model']

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

        # clear memory
        del checkpoint
        # del model_dict
        torch.cuda.empty_cache()
    else:
        print("=> creating Model")
        model = get_models(args.dataset, pretrained=True, freeze=True)
        print("=> model created.")
        start_epoch = 0

        # different modules use different learning rates
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        optimizer = torch.optim.SGD(train_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # DataParallel() works with a single GPU as well as with multiple GPUs
        model = nn.DataParallel(model).cuda()

    # use ReduceLROnPlateau to lower the learning rate when the validation metric plateaus
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=args.lr_patience)

    # loss function
    criterion = criteria.ordLoss()

    # create directory path
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str)

    # create log
    log_path = os.path.join(
        output_directory, 'logs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if os.path.isdir(log_path):
        shutil.rmtree(log_path)
    os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    for epoch in range(start_epoch, args.epochs):

        # log the current learning rate of each param group
        for i, param_group in enumerate(optimizer.param_groups):
            old_lr = float(param_group['lr'])
            logger.add_scalar('Lr/lr_' + str(i), old_lr, epoch)

        train(train_loader, model, criterion, optimizer, epoch,
              logger)  # train for one epoch
        result, img_merge = validate(val_loader, model, epoch,
                                     logger)  # evaluate on validation set

        # remember best rmse and save checkpoint
        is_best = result.rmse < best_result.rmse
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write(
                    "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, dd31={:.3f}, "
                    "t_gpu={:.4f}".format(epoch, result.rmse, result.absrel,
                                          result.lg10, result.delta1,
                                          result.delta2, result.delta3,
                                          result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        # save checkpoint for each epoch
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        # reduce the learning rate when rml (absolute relative error) stops falling
        scheduler.step(result.absrel)

    logger.close()
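Note: all three examples pass the full model object (not just a state dict) to utils.save_checkpoint, so resuming a run for inference is mostly a torch.load call. A hedged sketch, assuming the same PyTorch version as the training scripts, a hypothetical checkpoint path, and a 257x353 NYU-style input crop (both path and crop size are assumptions, not taken from the code above):

import torch

checkpoint = torch.load('run_nyu/checkpoint-10.pth.tar')  # hypothetical path
model = checkpoint['model']  # the nn.DataParallel-wrapped network saved above
print('best result so far:', checkpoint['best_result'])

model.eval()
with torch.no_grad():
    rgb = torch.rand(1, 3, 257, 353).cuda()  # 257x353 is an assumed NYU crop size
    output = model(rgb)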