Beispiel #1
0
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value (including ``"False"``) ends up ``True``. This
    converter recognises the usual spellings instead.

    Args:
        value: the raw command-line string (or an actual ``bool``).

    Returns:
        The parsed boolean. An empty string is treated as ``False``, which is
        what ``bool("")`` used to produce.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {0!r}'.format(value))


def main():
    """Parse the command line, then either test or train a DexiNet model.

    With ``--is_testing`` true (the default) the checkpoint named by
    ``--checkpoint_data`` is loaded from ``--output_dir``, every batch of the
    validation/test loader is pushed through the network, the predictions are
    written with ``save_image_batch_to_disk`` and the process exits.

    Otherwise the model is trained for ``--num-epochs`` epochs with Adam;
    after each epoch ``validation`` is run and the weights are saved to
    ``<output_dir>/<epoch>/<epoch>_model.pth``.
    """
    # Testing settings
    DATASET_NAME = ['BIPED', 'BSDS', 'BSDS300', 'CID', 'DCD', 'MULTICUE',
                    'PASCAL', 'NYUD', 'CLASSIC']
    TEST_DATA = DATASET_NAME[5]
    data_inf = dataset_info(TEST_DATA)
    # training settings
    parser = argparse.ArgumentParser(description='Training application.')
    # Data parameters
    parser.add_argument('--input-dir', type=str, default='/opt/dataset/BIPED/edges',
                        help='the path to the directory with the input data.')
    parser.add_argument('--input-val-dir', type=str, default=data_inf['data_dir'],
                        help='the path to the directory with the input data for validation.')
    parser.add_argument('--output_dir', type=str, default='checkpoints',
                        help='the path to output the results.')
    parser.add_argument('--test_data', type=str, default=TEST_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_list', type=str, default=data_inf['file_name'],
                        help='Name of the dataset.')
    # Boolean options go through _str2bool so that "--is_testing False"
    # really disables the flag.
    parser.add_argument('--is_testing', type=_str2bool, default=True,
                        help='Just for testing')  # just for testing True
    parser.add_argument('--use_prev_trained', type=_str2bool, default=True,
                        help='use previous trained data')  # Just for test
    parser.add_argument('--checkpoint_data', type=str, default='24/24_model.pth',
                        help='Just for testing')  # '19/19_*.pht'
    parser.add_argument('--test_im_width', type=int, default=data_inf['img_width'],
                        help='image width for testing')
    parser.add_argument('--test_im_height', type=int, default=data_inf['img_height'],
                        help=' image height for testing')
    parser.add_argument('--res_dir', type=str, default='result',
                        help='Result directory')
    parser.add_argument('--log-interval-vis', type=int, default=50,
                        help='how many batches to wait before logging training status')
    # Optimization parameters
    # NOTE(review): --optimizer and --wd are parsed but not read anywhere in
    # this function; the optimizer below is always Adam with weight_decay=1e-4.
    parser.add_argument('--optimizer', type=str, choices=['adam', 'sgd'], default='adam',
                        help='the optimization solver to use (default: adam)')
    parser.add_argument('--num-epochs', type=int, default=25, metavar='N',
                        help='number of training epochs (default: 25)')
    parser.add_argument('--wd', type=float, default=1e-5, metavar='WD',
                        help='weight decay (default: 1e-5)')
    parser.add_argument('--lr', default=1e-4, type=float,
                        help='Initial learning rate.')
    # argparse does not apply ``type`` to non-string defaults, so the default
    # has to be an int already.
    parser.add_argument('--lr_stepsize', default=int(1e4), type=int,
                        help='Learning rate step size.')
    parser.add_argument('--batch-size', type=int, default=8, metavar='B',
                        help='the mini-batch size (default: 8)')
    parser.add_argument('--num-workers', default=8, type=int,
                        help='the number of workers for the dataloader.')
    # NOTE(review): store_true with default=True means this can never be
    # switched off from the command line.
    parser.add_argument('--tensorboard', action='store_true', default=True,
                        help='use tensorboard for logging purposes')
    parser.add_argument('--gpu', type=str, default='1',
                        help='select GPU')
    parser.add_argument('--img_width', type=int, default=400, help='image size for training')
    parser.add_argument('--img_height', type=int, default=400, help='image size for training')
    # nargs='+' keeps these options lists when they are given on the command
    # line, matching the shape of their defaults.
    parser.add_argument('--channel_swap', default=[2, 1, 0], type=int, nargs='+',
                        help='channel order applied to the input images')
    parser.add_argument('--crop_img', default=False, type=_str2bool,
                        help='If true crop training images, other ways resizing')
    parser.add_argument('--mean_pixel_values', default=[104.00699, 116.66877, 122.67892, 137.86],
                        type=float, nargs='+',
                        help='per-channel mean pixel values')
    # other mean values seen in use: [103.939,116.779,123.68] [104.00699, 116.66877, 122.67892]
    args = parser.parse_args()

    tb_writer = None
    if args.tensorboard and not args.is_testing:
        from tensorboardX import SummaryWriter  # previous torch version
        # from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greater
        tb_writer = SummaryWriter(log_dir=args.output_dir)

    # Restrict the visible GPUs *before* the first CUDA query: once the CUDA
    # runtime has been initialised, changing CUDA_VISIBLE_DEVICES has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    print(" **** You have available ", torch.cuda.device_count(), "GPUs!")
    print("Pytorch version: ", torch.__version__)
    device = torch.device('cpu' if torch.cuda.device_count() == 0 else 'cuda')
    model = DexiNet().to(device)
    # model = nn.DataParallel(model)
    model.apply(weight_init)

    if not args.is_testing:
        dataset_train = BipedMyDataset(args.input_dir, train_mode='train',
                                       arg=args)
        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers)
    dataset_val = testDataset(args.input_val_dir, arg=args)
    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=False, num_workers=args.num_workers)

    # for testing
    if args.is_testing:
        checkpoint_path = os.path.join(args.output_dir, args.checkpoint_data)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()

        output_dir = os.path.join(args.res_dir, "BIPED2" + args.test_data)
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader_val):
                images = sample_batched['images'].to(device)
                if not args.test_data == "CLASSIC":
                    # Moved to the device but not used further in this loop.
                    labels = sample_batched['labels'].to(device)
                file_names = sample_batched['file_names']
                image_shape = sample_batched['image_shape']
                print("input image size: ", images.shape)
                output = model(images)
                save_image_batch_to_disk(output, output_dir, file_names, image_shape, arg=args)

        print("Testing ended in ", args.test_data, "dataset")
        sys.exit()

    criterion = weighted_cross_entropy_loss
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # Learning rate scheduler (currently disabled).
    # lr_schd = lr_scheduler.StepLR(optimizer, step_size=args.lr_stepsize,
    #                               gamma=args.lr_gamma)

    for epoch in range(args.num_epochs):
        # Create output directory
        output_dir_epoch = os.path.join(args.output_dir, str(epoch))
        img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
        create_directory(output_dir_epoch)
        create_directory(img_test_dir)

        train(epoch, dataloader_train, model, criterion, optimizer, device,
              args.log_interval_vis, tb_writer, args=args)

        # lr_schd.step() # decay lr at the end of the epoch.

        with torch.no_grad():
            validation(epoch, dataloader_val, model, device, img_test_dir, arg=args)

        # A wrapper such as nn.DataParallel keeps the real network in
        # ``.module``; fall back to the model itself when it is not wrapped.
        net_state_dict = getattr(model, 'module', model).state_dict()

        torch.save(net_state_dict, os.path.join(
                   output_dir_epoch, '{0}_model.pth'.format(epoch)))
import random # import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision.transforms import transforms

from model import DexiNet
import os

# Instantiate DexiNet on the CPU, restore the weights stored in
# checkpoints/24/24_model.pth and switch the network to evaluation mode.
device = torch.device('cpu')
model = DexiNet().to(device)
model.load_state_dict(
    torch.load(
        os.path.join('checkpoints', '24/24_model.pth'),
        map_location=device,
    )
)
model.eval()


from torch.autograd import Variable
import torch
import onnx

# An example input you would normally provide to your model's forward() method.
# All-ones tensor of shape (1, 3, 400, 400); presumably
# (batch, channels, height, width) -- TODO confirm against DexiNet.forward.
input = torch.ones(1, 3, 400, 400)

# Sanity forward pass. Gradients are not needed here, so skip building the
# autograd graph.
with torch.no_grad():
    raw_output = model(input)

print(raw_output.shape)

# Export the network, using the example input above, to 24_model.onnx with
# the trained parameters embedded in the file.
torch.onnx.export(model, input, '24_model.onnx', verbose=False, export_params=True)

print("-------------------------check model---------------------------------------\n")