def main():
    """Train DenseASPP on Cityscapes (fine annotations), validating each epoch.

    When ``args.checkpoint_path`` is non-empty, the network is warm-started
    from the given pretrained state dict; the final batch-norm and classifier
    tensors of the pretrained model are dropped because they do not match the
    segmentation head.
    """
    net = DenseASPP(model_cfg=DenseASPP121.Model_CFG).cuda()

    if len(args.checkpoint_path) > 0:
        # Warm-start: merge the pretrained weights into a fresh state dict.
        print('training resumes from ' + args.checkpoint_path)
        pretrained_weight = torch.load(args.checkpoint_path)
        new_state_dict = OrderedDict(pretrained_weight.items())
        # Drop tensors with no shape-compatible counterpart in DenseASPP
        # (final norm layer and the ImageNet classifier head).
        for stale_key in ('features.norm5.weight',
                          'features.norm5.bias',
                          'features.norm5.running_mean',
                          'features.norm5.running_var',
                          'classifier.weight',
                          'classifier.bias'):
            new_state_dict.pop(stale_key)
        model_dict = net.state_dict()
        model_dict.update(new_state_dict)
        net.load_state_dict(model_dict, strict=False)

    # Bookkeeping is identical whether or not a checkpoint was loaded:
    # training always restarts from epoch 1 with a fresh best record.
    curr_epoch = 1
    args.best_record = {
        'epoch': 0,
        'val_loss': 1e10,
        'acc': 0,
        'acc_cls': 0,
        'mean_iu': 0,
        'fwavacc': 0
    }

    cudnn.benchmark = True

    # Dataset statistics consumed by DeNormalize when restoring images for
    # visualisation.
    mean_std = ([0.290101, 0.328081, 0.286964], [0.182954, 0.186566, 0.184475])

    # ---------------------------- data augmentation ---------------------------
    # Joint transforms act on the raw image and its ground truth together so
    # that geometric changes stay aligned between the two.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomSizedCrop(size=args.input_width),
        joint_transforms.RandomHorizontallyFlip()
    ])

    val_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomSizedCrop(size=args.input_width)
    ])

    # Image-only preprocessing (color jitter + tensor conversion).
    input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(hue=0.1),
        standard_transforms.ToTensor(),
    ])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(hue=0.1),
        standard_transforms.ToTensor(),
    ])

    # Ground-truth masks are converted to integer label tensors.
    target_transform = extended_transforms.Compose([
        extended_transforms.MaskToTensor(),
    ])

    # Used only for visualisation: undo normalization, convert back to PIL.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()
    # --------------------------------------------------------------------------

    train_set = segmentation_dataloader.CityScapes(
        'fine',
        'train',
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args.train_batch_size,
                              num_workers=args.num_threads,
                              shuffle=True)

    val_set = segmentation_dataloader.CityScapes(
        'fine',
        'val',
        joint_transform=val_joint_transform,
        transform=val_input_transform,
        target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args.val_batch_size,
                            num_workers=args.num_threads,
                            shuffle=False)

    # Pixels labelled with ignore_label contribute nothing to the loss.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=segmentation_dataloader.ignore_label)
    optimizer = optim.Adam(net.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, 'Model', ImageNet, exp_name_ImageNet))
    # Log the full argument set once per run; `with` closes the handle
    # deterministically (the original left it to the garbage collector).
    log_path = os.path.join(ckpt_path, 'Model', ImageNet, exp_name_ImageNet,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    for epoch in range(curr_epoch, args.num_epochs + 1):
        train(train_loader, net, criterion, optimizer, epoch, args, train_set)
        validate(val_loader, net, criterion, optimizer, epoch, args,
                 restore_transform, visualize)

    print('Training Done!!')
# Example #2
def main():
    """Train DenseASPP_boundary_depthwise on Cityscapes fine annotations.

    Optionally warm-starts from ``args.checkpoint_path``, trains for
    ``args.num_epochs`` epochs with a poly learning-rate schedule, and saves
    the final weights under the experiment directory.
    """
    net = DenseASPP_boundary_depthwise(model_cfg=DenseASPP121.Model_CFG).cuda()

    if len(args.checkpoint_path) > 0:
        print('training resumes from ' + args.checkpoint_path)
        # map_location loads every tensor onto the CPU, regardless of the
        # device it was saved from.
        pretrained_weight = torch.load(
            args.checkpoint_path, map_location=lambda storage, loc: storage)
        new_state_dict = OrderedDict()
        model_dict = net.state_dict()
        for key, value in pretrained_weight.items():
            new_state_dict[key] = value
            # 'features.' is 9 characters long, so this matches keys such as
            # 'features.norm*' but not a key that begins with 'norm'.
            if key.find('norm') >= 9:
                # NOTE(review): this flags the *checkpoint* tensors, not the
                # parameters of `net`; load_state_dict copies data only, so
                # this likely freezes nothing — confirm the intent.
                value.requires_grad = False

        # Drop tensors that have no shape-compatible counterpart in this
        # network (replaced stem conv, final norm, ImageNet classifier).
        for stale_key in ('features.conv0.weight',
                          'features.norm5.weight',
                          'features.norm5.bias',
                          'features.norm5.running_mean',
                          'features.norm5.running_var',
                          'classifier.weight',
                          'classifier.bias'):
            new_state_dict.pop(stale_key)
        model_dict.update(new_state_dict)
        net.load_state_dict(model_dict)

    # Bookkeeping is identical whether or not a checkpoint was loaded:
    # training always restarts from epoch 1 with a fresh best record.
    curr_epoch = 1
    args.best_record = {
        'epoch': 0,
        'val_loss': 1e10,
        'acc': 0,
        'acc_cls': 0,
        'mean_iu': 0,
        'fwavacc': 0
    }

    # ---------------------------- data augmentation ---------------------------
    # Joint transforms act on the raw image and its ground truth together so
    # that geometric changes stay aligned between the two.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomSizedCrop(size=args.input_width),
    ])

    # Image-only preprocessing (color jitter + tensor conversion).
    input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(hue=0.1),
        standard_transforms.ToTensor(),
    ])

    # Ground-truth masks are converted to integer label tensors.
    target_transform = extended_transforms.MaskToTensor()

    train_set = segmentation_dataloader.CityScapes(
        'fine',
        'train',
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args.train_batch_size,
                              num_workers=args.num_threads,
                              shuffle=True)

    # Pixels labelled with ignore_label contribute nothing to the loss.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=segmentation_dataloader.ignore_label).cuda()

    num_training_samples = len(train_set)
    steps_per_epoch = np.ceil(num_training_samples /
                              args.train_batch_size).astype(np.int32)
    num_total_steps = args.num_epochs * steps_per_epoch

    print("total number of samples: {}".format(num_training_samples))
    print("total number of steps  : {}".format(num_total_steps))

    # Count parameters; the total is reported and embedded in the saved
    # model's filename.
    total_num_paramters = 0
    for param in net.parameters():
        total_num_paramters += np.array(list(param.size())).prod()

    print("number of trainable parameters: {}".format(total_num_paramters))

    for epoch in range(curr_epoch, args.num_epochs + 1):
        # Poly learning-rate decay: a fresh Adam instance is built per epoch.
        # NOTE(review): this also resets Adam's moment estimates every epoch —
        # confirm that is intended.
        lr_ = poly_lr_scheduler(init_lr=args.learning_rate, epoch=epoch - 1)
        optimizer = optim.Adam(net.parameters(),
                               lr=lr_,
                               weight_decay=args.weight_decay)

        train(train_loader, net, criterion, optimizer, epoch, args,
              total_num_paramters)

    torch.save(
        net.state_dict(),
        os.path.join(ckpt_path, 'Model', ImageNet, exp_name_ImageNet,
                     'model-{}'.format(total_num_paramters) + '.pkl'))
    print('Training Done!!')
def main():
    """Train DenseASPP on Cityscapes and run validation after every epoch.

    Uses a poly learning-rate schedule (a new Adam optimizer per epoch) and
    can warm-start from the pretrained checkpoint in
    ``args.checkpoint_path``.
    """
    net = DenseASPP(model_cfg=DenseASPP121.Model_CFG).cuda()

    if len(args.checkpoint_path) > 0:
        print('training resumes from ' + args.checkpoint_path)
        # map_location loads every tensor onto the CPU, regardless of the
        # device it was saved from.
        pretrained_weight = torch.load(
            args.checkpoint_path, map_location=lambda storage, loc: storage)
        new_state_dict = OrderedDict()
        model_dict = net.state_dict()
        for key, value in pretrained_weight.items():
            new_state_dict[key] = value
            # 'features.' is 9 characters long, so this matches keys such as
            # 'features.norm*' but not a key that begins with 'norm'.
            if key.find('norm') >= 9:
                print('norm contained from pretrained_weight : ', key)
                # NOTE(review): this flags the *checkpoint* tensors, not the
                # parameters of `net`; load_state_dict copies data only, so
                # this likely freezes nothing — confirm the intent.
                value.requires_grad = False

        # Drop tensors with no shape-compatible counterpart in the
        # segmentation network (final norm layer, ImageNet classifier).
        for stale_key in ('features.norm5.weight',
                          'features.norm5.bias',
                          'features.norm5.running_mean',
                          'features.norm5.running_var',
                          'classifier.weight',
                          'classifier.bias'):
            new_state_dict.pop(stale_key)
        model_dict.update(new_state_dict)
        net.load_state_dict(model_dict)

    # Bookkeeping is identical whether or not a checkpoint was loaded:
    # training always restarts from epoch 1 with a fresh best record.
    curr_epoch = 1
    args.best_record = {
        'epoch': 0,
        'val_loss': 1e10,
        'acc': 0,
        'acc_cls': 0,
        'mean_iu': 0,
        'fwavacc': 0
    }

    # ImageNet statistics; consumed by DeNormalize when restoring images for
    # visualisation.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # ---------------------------- data augmentation ---------------------------
    # Joint transforms act on the raw image and its ground truth together so
    # that geometric changes stay aligned between the two.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomSizedCrop(size=args.input_width),
    ])

    val_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomSizedCrop(size=args.input_width),
    ])

    # Image-only preprocessing (color jitter + tensor conversion).
    input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(hue=0.1),
        standard_transforms.ToTensor(),
    ])

    val_input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(hue=0.1),
        standard_transforms.ToTensor(),
    ])

    # Ground-truth masks are converted to integer label tensors.
    target_transform = extended_transforms.MaskToTensor()

    # Used only for visualisation: undo normalization, convert back to PIL.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()
    # --------------------------------------------------------------------------

    train_set = segmentation_dataloader.CityScapes(
        'fine',
        'train',
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args.train_batch_size,
                              num_workers=args.num_threads,
                              shuffle=True)

    val_set = segmentation_dataloader.CityScapes(
        'fine',
        'val',
        joint_transform=val_joint_transform,
        transform=val_input_transform,
        target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args.val_batch_size,
                            num_workers=args.num_threads,
                            shuffle=False)

    # Pixels labelled with ignore_label contribute nothing to the loss.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=segmentation_dataloader.ignore_label).cuda()

    for epoch in range(curr_epoch, args.num_epochs + 1):
        # Poly learning-rate decay: a fresh Adam instance is built per epoch.
        # NOTE(review): this also resets Adam's moment estimates every epoch —
        # confirm that is intended.
        lr_ = poly_lr_scheduler(init_lr=args.learning_rate, epoch=epoch - 1)
        optimizer = optim.Adam(net.parameters(),
                               lr=lr_,
                               weight_decay=args.weight_decay)

        train(train_loader, net, criterion, optimizer, epoch, args, train_set)
        # NOTE(review): the net is switched to eval mode for validation but
        # never explicitly switched back here — verify train() calls
        # net.train().
        net.eval()
        validate(val_loader, net, criterion, optimizer, epoch, args,
                 restore_transform, visualize)

    print('Training Done!!')