Example #1
def main(config):

    # device
    device = torch.device("cuda")
    initial_epoch = 80
    num_epochs = 100

    outdir = '/usr/local/data/raghav/ECSE626_2019/Project/Experiments/'  # Full path to the directory where all generated files are stored
    main_path = '/usr/local/data/raghav/ECSE626_2019/Project/data/'  # Root directory of the input data (unused below; the dataset root is passed to Cityscapes directly)
    ConfigName = 'Seg_Depth_MTWL'  # Configuration name to uniquely identify this experiment

    #########################################################################################################

    log_path = join(outdir, ConfigName, 'log')

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(join(log_path, 'weights'), exist_ok=True)
    os.makedirs(join(log_path, 'visualize'), exist_ok=True)

    ##################################################################################################

    #####################################################################################################

    model = DeepLab(backbone='resnet',
                    output_stride=8,
                    num_classes=[20, 1],
                    sync_bn=False,
                    freeze_bn=False)

    print("===> Model Defined.")

    #############################################################################################
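    # Uncertainty-based multi-task weighting: the two scalars below act as
    # learnable log-variances s_i, which appears to follow the
    # homoscedastic-uncertainty loss of Kendall et al. The corresponding task
    # "precisions" exp(-s_i) are what get logged, and both scalars are
    # appended to the optimizer's parameter list so they are learned jointly
    # with the network. Under that formulation the combined training loss
    # would typically be (up to constant factors)
    #   loss = exp(-s_seg) * loss_seg + exp(-s_reg) * loss_reg + s_seg + s_reg
    # (the training loop itself is not shown in this example, so the exact
    # combination used during training is an assumption).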

    reg_task_weight = torch.zeros(1, requires_grad=True, device=device)
    seg_task_weight = torch.zeros(1, requires_grad=True, device=device)

    print("===> Initial Reg Weight and Seg weight: {}  {}".format(
        torch.exp(-reg_task_weight), torch.exp(-seg_task_weight)))
    print("===> Initial Reg Weight and Seg weight Require Grad: {}  {}".format(
        reg_task_weight.requires_grad, seg_task_weight.requires_grad))

    params = list(model.parameters()) + [reg_task_weight, seg_task_weight]

    #############################################################################################
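    # SGD with Nesterov momentum plus the "poly" learning-rate decay
    # (1 - epoch / num_epochs)**0.9 that is commonly paired with DeepLab,
    # applied on top of the base LR via LambdaLR.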

    optimizer = optim.SGD(params, lr=0.0025, momentum=0.9, nesterov=True)
    lambda2 = lambda epoch: (1 - epoch / num_epochs)**0.9
    scheduler = LambdaLR(optimizer, lr_lambda=lambda2)

    print("===> Optimizer Initialized")

    ########################################################################################
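    # Optionally resume: restore the model and optimizer state plus the two
    # learned task-weight scalars from the checkpoint saved at `initial_epoch`.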

    if initial_epoch > 0:
        print("===> Loading pre-trained weight {}".format(initial_epoch))
        weight_path = 'weights/model-{:04d}.pt'.format(initial_epoch)
        # model = torch.load(join(log_path, weight_path))
        checkpoint = torch.load(join(log_path, weight_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        reg_task_weight = checkpoint['reg_task_weight']
        seg_task_weight = checkpoint['seg_task_weight']
        # epoch = checkpoint['epoch']

    #print(model)
    model = model.to(device)

    ##################################################################################################
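    # Input / label / depth preprocessing. PIL_To_Tensor, NormalizeRange,
    # Project, DepthConversion and the CITYSCAPES_* constants are
    # project-specific helpers defined elsewhere. Labels are resized with
    # nearest-neighbour interpolation (interpolation=0) so class ids are not
    # mixed, while images and depth maps use bilinear (interpolation=2).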

    input_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=2),
        PIL_To_Tensor(),
        NormalizeRange(),
        transforms.Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD)
    ])
    label_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=0),
        PIL_To_Tensor(),
        Project(projection=CITYSCAPES_CLASSES_TO_LABELS)
    ])
    depth_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=2),
        PIL_To_Tensor(),
        DepthConversion()
    ])
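
    # The Cityscapes dataset is asked for two targets per sample (semantic
    # labels and depth); torchvision's stock Cityscapes class does not expose
    # a 'depth' target_type or accept a list of target transforms, so this
    # presumably relies on a customised dataset wrapper. Each batch is then
    # (image, [semantic, depth]).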

    training_data_loader = torch.utils.data.DataLoader(Cityscapes(
        root='/usr/local/data/raghav/ECSE626_2019/Project/data/',
        split='train',
        mode='fine',
        target_type=['semantic', 'depth'],
        transform=input_transform,
        target_transform=[label_transform, depth_transform]),
                                                       batch_size=8,
                                                       shuffle=True,
                                                       num_workers=4,
                                                       drop_last=True)

    validation_data_loader = torch.utils.data.DataLoader(Cityscapes(
        root='/usr/local/data/raghav/ECSE626_2019/Project/data/',
        split='val',
        mode='fine',
        target_type=['semantic', 'depth'],
        transform=input_transform,
        target_transform=[label_transform, depth_transform]),
                                                         batch_size=8,
                                                         shuffle=True,
                                                         num_workers=4,
                                                         drop_last=True)

    print("===> Training and Validation Data Loaderss Initialized")

    ##########################################################################################################
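    # Column layout shared by the logger and the metric vectors below: the
    # seven loss-related columns come first, followed by meanIoU, MSE and MAE.
    # validate() must accumulate values in exactly this order.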

    my_metric = ['meanIoU', "mse", "mae"]

    my_loss = [
        "loss", "loss_seg", "loss_reg", "reg_weight", "seg_weight",
        "reg_precision", "seg_precision"
    ]

    logger = Logger(mylog_path=log_path,
                    mylog_name="training.log",
                    myloss_names=my_loss,
                    mymetric_names=my_metric)
    LP = LossPlotter(mylog_path=log_path,
                     mylog_name="training.log",
                     myloss_names=my_loss,
                     mymetric_names=my_metric)

    print("===> Logger and LossPlotter Initialized")

    ############################################################################################
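    # Save everything needed to resume training, including the two learned
    # task-weight scalars, to log_path/weights/model-XXXX.pt.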

    def checkpoint(epoch):
        w_path = 'weights/model-{:04d}.pt'.format(epoch)
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'seg_task_weight': seg_task_weight,
                'reg_task_weight': reg_task_weight
            }, join(log_path, w_path))
        print("===> Checkpoint saved to {}".format(w_path))

    #################################################################################################

    #####################################################################################################
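    # Validation pass: average per-batch losses and metrics over the whole
    # validation loader. The reported loss is the plain (unweighted) sum of
    # the segmentation and regression losses; the precisions exp(-s_i) are
    # computed only so they can be logged alongside it.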

    def validate():

        model.eval()

        metric = np.zeros(len(my_metric) + len(my_loss))

        seg_crit = nn.CrossEntropyLoss(reduction='none')
        reg_crit = nn.L1Loss(reduction='none')

        with torch.no_grad():

            for iteration, batch in enumerate(validation_data_loader):

                inp, target, seg = batch[0].to(device), batch[1][1].to(
                    device), batch[1][0].type('torch.LongTensor').squeeze().to(
                        device)

                outp = model(inp)

                loss = 0

                seg_precision = torch.exp(-seg_task_weight)
                loss_seg = seg_crit(outp[0], seg)
                loss_seg = torch.mean(loss_seg)

                reg_precision = torch.exp(-reg_task_weight)
                loss_reg = reg_crit(outp[1], target)
                loss_reg = torch.mean(loss_reg)

                loss = loss_seg + loss_reg

                loss = loss.item()

                loss_seg = loss_seg.item()

                loss_reg = loss_reg.item()

                seg = np.squeeze(seg.data.cpu().numpy().astype('float32'))
                outp_seg = torch.argmax(outp[0], dim=1, keepdim=False)
                outp_seg = np.squeeze(
                    outp_seg.data.cpu().numpy().astype('float32'))

                mIoU = meanIoU(seg, outp_seg)

                mean_squared_error = torch.mean((outp[1] - target)**2).item()

                mean_absolute_error = torch.mean(torch.abs(outp[1] -
                                                           target)).item()

                metric += np.array([
                    loss, loss_seg, loss_reg,
                    reg_task_weight.item(),
                    seg_task_weight.item(),
                    reg_precision.item(),
                    seg_precision.item(), mIoU, mean_squared_error,
                    mean_absolute_error
                ])

                # if iteration==10:
                #     break

        return metric / len(validation_data_loader)
        # return metric/10

    #########################################################################################################

    total_params = sum(p.numel() for p in model.parameters())

    val_metric = validate()
    print(
        "===> Validation Epoch {}: Loss - {:.4f}, Loss_Seg - {:.4f}, Loss_Reg - {:.4f}, mean IoU - {:.4f}, MSE - {:.4f}, MAE - {:.4f}"
        .format(initial_epoch, val_metric[0], val_metric[1], val_metric[2],
                val_metric[7], val_metric[8], val_metric[9]))
Example #2
def main(config):

    # device
    device = torch.device("cuda")
    initial_epoch = 0
    num_epochs = 100

    outdir = '/usr/local/data/raghav/ECSE626_2019/Project/Experiments/'  # Full path to the directory where all generated files are stored
    main_path = '/usr/local/data/raghav/ECSE626_2019/Project/data/'  # Root directory of the input data (unused below; the dataset root is passed to Cityscapes directly)
    ConfigName = 'Depth'  # Configuration name to uniquely identify this experiment

    #########################################################################################################

    log_path = join(outdir, ConfigName, 'log')

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(join(log_path, 'weights'), exist_ok=True)
    os.makedirs(join(log_path, 'visualize'), exist_ok=True)

    ##################################################################################################

    #####################################################################################################

    model = DeepLab(backbone='resnet',
                    output_stride=8,
                    num_classes=[1],
                    sync_bn=False,
                    freeze_bn=False)

    print("===> Model Defined.")

    #############################################################################################

    params = list(model.parameters())

    #############################################################################################

    optimizer = optim.SGD(params, lr=0.0025, momentum=0.9, nesterov=True)
    lambda2 = lambda epoch: (1 - epoch / num_epochs)**0.9
    scheduler = LambdaLR(optimizer, lr_lambda=lambda2)

    print("===> Optimizer Initialized")

    ########################################################################################

    if initial_epoch > 0:
        print("===> Loading pre-trained weight {}".format(initial_epoch))
        weight_path = 'weights/model-{:04d}.pt'.format(initial_epoch)
        # model = torch.load(join(log_path, weight_path))
        checkpoint = torch.load(join(log_path, weight_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # epoch = checkpoint['epoch']

    #print(model)
    model = model.to(device)

    ##################################################################################################

    input_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=2),
        PIL_To_Tensor(),
        NormalizeRange(),
        transforms.Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD)
    ])
    label_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=0),
        PIL_To_Tensor(),
        Project(projection=CITYSCAPES_CLASSES_TO_LABELS)
    ])
    depth_transform = transforms.Compose([
        transforms.Resize(size=128, interpolation=2),
        PIL_To_Tensor(),
        DepthConversion()
    ])

    training_data_loader = torch.utils.data.DataLoader(Cityscapes(
        root='/usr/local/data/raghav/ECSE626_2019/Project/data/',
        split='train',
        mode='fine',
        target_type=['depth'],
        transform=input_transform,
        target_transform=[depth_transform]),
                                                       batch_size=8,
                                                       shuffle=True,
                                                       num_workers=4,
                                                       drop_last=True)

    validation_data_loader = torch.utils.data.DataLoader(Cityscapes(
        root='/usr/local/data/raghav/ECSE626_2019/Project/data/',
        split='val',
        mode='fine',
        target_type=['depth'],
        transform=input_transform,
        target_transform=[depth_transform]),
                                                         batch_size=8,
                                                         shuffle=True,
                                                         num_workers=4,
                                                         drop_last=True)

    print("===> Training and Validation Data Loaderss Initialized")

    ##########################################################################################################

    my_metric = ["mse", "mae"]

    my_loss = ["loss"]  #, "loss_seg"]

    logger = Logger(mylog_path=log_path,
                    mylog_name="training.log",
                    myloss_names=my_loss,
                    mymetric_names=my_metric)
    LP = LossPlotter(mylog_path=log_path,
                     mylog_name="training.log",
                     myloss_names=my_loss,
                     mymetric_names=my_metric)

    print("===> Logger and LossPlotter Initialized")

    ############################################################################################

    def checkpoint(epoch):
        w_path = 'weights/model-{:04d}.pt'.format(epoch)
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, join(log_path, w_path))
        print("===> Checkpoint saved to {}".format(w_path))

    #################################################################################################
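    # One training epoch of single-task depth regression: L1 loss on the
    # predicted depth map. scheduler.step() is called at the start of the
    # epoch (the pre-PyTorch-1.1 convention); newer PyTorch versions expect it
    # after optimizer.step() and will warn about this ordering.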

    def train(epoch):

        model.train()

        scheduler.step()

        epoch_loss = 0

        seg_crit = nn.CrossEntropyLoss(reduction='none')
        reg_crit = nn.L1Loss(reduction='none')

        for iteration, batch in enumerate(tqdm(training_data_loader)):

            optimizer.zero_grad()

            inp, target = batch[0].to(device), batch[1].to(device)

            outp = model(inp)

            loss = 0

            loss_reg = reg_crit(outp[0], target)
            loss_reg = torch.mean(loss_reg)

            loss = loss_reg

            epoch_loss += loss.item()

            loss.backward()

            optimizer.step()

            # if iteration==10:
            #     break

        print("===> Training Epoch {} Complete: Avg. Loss: {:.4f}".format(
            epoch, epoch_loss / len(training_data_loader)))

    ###################################################################################################
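    # Evaluate loss/MSE/MAE on the *training* loader (no gradient updates),
    # so training-set numbers can be logged next to the validation ones.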

    def test():

        model.eval()

        metric = np.zeros(len(my_metric) + len(my_loss))

        seg_crit = nn.CrossEntropyLoss(reduction='none')
        reg_crit = nn.L1Loss(reduction='none')

        with torch.no_grad():

            for iteration, batch in enumerate(training_data_loader):

                inp, target = batch[0].to(device), batch[1].to(
                    device)  #batch[1][1].to(device), target,

                outp = model(inp)

                loss = 0

                loss_reg = reg_crit(outp[0], target)
                loss_reg = torch.mean(loss_reg)

                loss = loss_reg

                loss = loss.item()

                mean_squared_error = torch.mean((outp[0] - target)**2).item()

                mean_absolute_error = torch.mean(torch.abs(outp[0] -
                                                           target)).item()

                metric += np.array(
                    [loss, mean_squared_error, mean_absolute_error])

                # if iteration==10:
                #     break

        return metric / len(training_data_loader)
        # return metric/10

    #####################################################################################################

    def validate():

        model.eval()

        metric = np.zeros(len(my_metric) + len(my_loss))

        seg_crit = nn.CrossEntropyLoss(reduction='none')
        reg_crit = nn.L1Loss(reduction='none')

        with torch.no_grad():

            for iteration, batch in enumerate(validation_data_loader):

                inp, target = batch[0].to(device), batch[1].to(
                    device)  #batch[1][1].to(device), target,

                outp = model(inp)

                loss = 0

                loss_reg = reg_crit(outp[0], target)
                loss_reg = torch.mean(loss_reg)

                loss = loss_reg

                loss = loss.item()

                mean_squared_error = torch.mean((outp[0] - target)**2).item()

                mean_absolute_error = torch.mean(torch.abs(outp[0] -
                                                           target)).item()

                metric += np.array(
                    [loss, mean_squared_error, mean_absolute_error])

                # if iteration==10:
                #     break

        return metric / len(validation_data_loader)
        # return metric/10

    #####################################################################################################
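    # Dump the first validation batch as PNGs: ground-truth and predicted
    # depth maps, scaled to [0, 255] and cast to uint8 (this assumes the depth
    # values are normalised to [0, 1] by DepthConversion).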

    def visualize(epch):

        model.eval()

        v_path = join(log_path, 'visualize', '{}'.format(epch))

        os.makedirs(v_path, exist_ok=True)

        with torch.no_grad():

            for iteration, batch in enumerate(validation_data_loader):

                inp, target = batch[0].to(device), batch[1].to(
                    device)  #batch[1][1].to(device), target,

                outp = model(inp)

                target = np.squeeze(
                    target.data.cpu().numpy().astype('float32'))
                outp_reg = np.squeeze(
                    outp[0].data.cpu().numpy().astype('float32'))

                ####
                for i in range(inp.shape[0]):
                    result = Image.fromarray(
                        (255 * target[i, ...].squeeze()).astype(np.uint8))
                    result.save(join(v_path, 'depth{}.png'.format(i)))
                    result = Image.fromarray(
                        (255 * outp_reg[i, ...].squeeze()).astype(np.uint8))
                    result.save(join(v_path, 'outp_depth{}.png'.format(i)))

                if iteration == 0:
                    break

    #########################################################################################################
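    # Main loop: train every epoch; every second epoch also evaluate on the
    # training and validation sets, log/plot the metrics and dump sample
    # outputs; checkpoint after every epoch.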

    total_params = sum(p.numel() for p in model.parameters())

    print("===> Starting Model Training at Epoch: {}".format(initial_epoch))
    print("===> Total Model Parameter: ", total_params)

    for epch in range(initial_epoch, num_epochs):

        start = time.time()

        print("\n\n")
        print("Epoch:{}".format(epch))
        train(epch)

        if epch % 2 == 0:

            train_metric = test()
            print(
                "===> Training   Epoch {}: Loss - {:.4f}, MSE - {:.4f}, MAE - {:.4f}"
                .format(epch, train_metric[0], train_metric[1],
                        train_metric[2]))

            val_metric = validate()
            print(
                "===> Validation Epoch {}: Loss - {:.4f}, MSE - {:.4f}, MAE - {:.4f}"
                .format(epch, val_metric[0], val_metric[1], val_metric[2]))

            logger.to_csv(np.concatenate((train_metric, val_metric)), epch)
            print("===> Logged All Metrics")

            LP.plotter()
            print("===> Plotted All Metrics")

            visualize(epch)
            print("===> Visualized some outputs")

        checkpoint(epch)
        end = time.time()
        print("===> Epoch:{} Completed in {} seconds".format(
            epch, end - start))

    print("===> Done Training for Total {} Epochs".format(num_epochs))
Example #3
trainloader, validloader, testloader = get_data_loader(args.data, args.batchSize, args.normalize)
print("===> Data Loaders Initialized")


############################################################################
#
# Initialize Logger and LossPlotter
#
############################################################################

my_metric = ["D_fake", "D_real", "GP"]

my_loss = ["G_loss", "D_loss"]

logger = Logger(mylog_path=log_path, mylog_name="training.log", myloss_names=my_loss, mymetric_names=my_metric)
LP = LossPlotter(mylog_path=log_path, mylog_name="training.log", myloss_names=my_loss, mymetric_names=my_metric)

print("===> Logger and LossPlotter Initialized")


############################################################################
#
# define Gradient Penalty loss and checkpoint
#
############################################################################
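# WGAN-GP gradient penalty: alpha is drawn uniformly per example and expanded
# to the shape of the real samples so that interpolates between real and fake
# batches can be formed; the penalty then pushes the norm of the
# discriminator's gradient at those interpolates towards 1.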

def calc_gradient_penalty(real_d, fake_d):

    # Calculate interpolation
    alpha = torch.rand(args.batchSize,1,1,1)
    alpha = alpha.expand_as(real_d)