Example #1
train_loader = DataLoader(train_data,
                          batch_size=cfg.BATCH_SIZE,
                          shuffle=True,
                          num_workers=cfg.WORKERS,
                          pin_memory=True,
                          drop_last=True)
val_loader = DataLoader(val_data,
                        batch_size=8,
                        shuffle=False,
                        num_workers=cfg.WORKERS,
                        pin_memory=True)
# val_loader=None
unlabeled_loader = None
num_train = len(train_data)
num_val = len(val_data)

cfg.CLASS_WEIGHTS_TRAIN = train_data.class_weights
cfg.IGNORE_LABEL = train_data.ignore_label

# shell script to run
print('LOSS_TYPES:', cfg.LOSS_TYPES)
writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard

if cfg.MULTI_MODAL:
    model = TRecgNet_MULTIMODAL(cfg, writer=writer)
else:
    model = TRecgNet(cfg, writer=writer)
model.set_data_loader(train_loader, val_loader, unlabeled_loader, num_train,
                      num_val)


def train():
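Example #1 takes precomputed per-class weights from the dataset (train_data.class_weights), while Examples #2 and #4 below simply pass raw class counts. For reference, a common recipe for turning label counts into loss weights is inverse frequency; the helper below is an illustrative assumption, not code from any of these projects.

import torch
from collections import Counter

def inverse_frequency_weights(labels):
    # labels: iterable of integer class ids, one per training sample.
    counts = Counter(labels)
    num_classes = max(counts) + 1
    freq = torch.tensor([counts.get(c, 0) for c in range(num_classes)],
                        dtype=torch.float)
    # Rare classes get proportionally larger weights; clamp avoids division by zero.
    return freq.sum() / (freq.clamp(min=1) * num_classes)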
Example #2
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            SPL10.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            SPL10.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            RGB2Lab(),
            SPL10.ToTensor(),
            SPL10.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
        ]))
    train_loader = DataProvider(cfg, dataset=train_dataset)
    val_loader = DataProvider(cfg, dataset=val_dataset, shuffle=False)

    num_classes_train = list(
        Counter([i[1] for i in train_loader.dataset.imgs]).values())
    cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    model = Contrastive_CrossModal_Conc(cfg, device=device)
    model = nn.DataParallel(model).to(device)
    load_model = False
    if load_model:
        model = torch.load('./checkpoint/model_1_LAB.mdl')
        print("loaded pretrained model")
    # Build the optimizer after any checkpoint load so it tracks the parameters
    # of the model that will actually be trained.
    optim = Adam(model.parameters(), lr=cfg.LR)
    # loss_optim = Adam(infomax_fn.parameters(), lr=2e-4)
    # cls_criterion = torch.nn.CrossEntropyLoss(cfg.CLASS_WEIGHTS_TRAIN.to(device))

    scheduler_optim = get_scheduler(optim)
    # scheduler_loss_optim = get_scheduler(loss_optim)

    epoch_restart = None
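get_scheduler(optim) is called above but its body is not part of this excerpt. A minimal sketch of a linear-decay scheduler such a helper might return (the decay policy and the total_iters argument are assumptions, not this project's code):

from torch.optim import lr_scheduler

def get_scheduler(optimizer, total_iters=100):
    # Linearly decay the learning rate from its initial value to zero
    # over total_iters scheduler steps (assumed policy).
    def lambda_rule(step):
        return 1.0 - step / float(total_iters)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)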
Example #3
def main():
    cfg = DefaultConfig()
    args = {
        'seg_resnet_sunrgbd': SEG_RESNET_SUNRGBD_CONFIG().args(),
        'seg_resnet_cityscapes': SEG_RESNET_CITYSCAPE_CONFIG().args(),
        'rec_resnet_sunrgbd': REC_RESNET_SUNRGBD_CONFIG().args(),
        'rec_resnet_nyud2': REC_RESNET_NYUD2_CONFIG().args(),
        'rec_resnet_mit67': REC_RESNET_MIT67_CONFIG().args(),
        'infomax_resnet_sunrgbd': INFOMAX_RESNET_SUNRGBD_CONFIG().args(),
        'infomax_resnet_nyud2': INFOMAX_RESNET_NYUD2_CONFIG().args()
    }
    # use shell
    if len(sys.argv) > 1:
        device_ids = torch.cuda.device_count()
        print('device_ids:', device_ids)
        gpu_ids, config_key = sys.argv[1:]
        cfg.parse(args[config_key])
        cfg.GPU_IDS = gpu_ids.split(',')

    else:
        # seg_resnet_sunrgbd
        # seg_resnet_cityscapes
        # infomax_resnet_sunrgbd
        # rec_resnet_sunrgbd
        # rec_resnet_nyud2
        # rec_resnet_mit67
        # infomax_resnet_nyud2
        config_key = 'rec_resnet_sunrgbd'
        cfg.parse(args[config_key])
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, cfg.GPU_IDS))

    trans_task = cfg.WHICH_DIRECTION
    if not cfg.NO_TRANS:
        if cfg.MULTI_MODAL:
            trans_task = trans_task + '_multimodal_'

        if 'SEMANTIC' in cfg.LOSS_TYPES or 'PIX2PIX' in cfg.LOSS_TYPES:
            trans_task = trans_task + '_alpha_' + str(cfg.ALPHA_CONTENT)

    evaluate_type = 'sliding_window' if cfg.SLIDE_WINDOWS else 'center_crop'
    log_name = ''.join(
        [cfg.TASK, '_', cfg.ARCH, '_', trans_task, '_', cfg.DATASET])
    cfg.LOG_NAME = ''.join([
        log_name, '_', '.'.join(cfg.LOSS_TYPES), '_', evaluate_type, '_gpus_',
        str(len(cfg.GPU_IDS)), '_',
        datetime.now().strftime('%b%d_%H-%M-%S')
    ])
    cfg.LOG_PATH = os.path.join(cfg.LOG_PATH, cfg.MODEL, cfg.LOG_NAME)

    # Setting random seed
    if cfg.MANUAL_SEED is None:
        cfg.MANUAL_SEED = random.randint(1, 10000)
    random.seed(cfg.MANUAL_SEED)
    torch.manual_seed(cfg.MANUAL_SEED)
    torch.backends.cudnn.benchmark = True
    # cudnn.deterministic = True

    project_name = reduce(lambda x, y: str(x) + '/' + str(y),
                          os.path.realpath(__file__).split(os.sep)[:-1])
    print('>>> task path is {0}'.format(project_name))

    util.mkdir('logs')

    # dataset = segmentation_dataset_cv2
    train_transforms = list()
    val_transforms = list()
    ms_targets = []

    train_transforms.append(dataset.Resize(cfg.LOAD_SIZE))
    # train_transforms.append(dataset.RandomScale(cfg.RANDOM_SCALE_SIZE))  #
    # train_transforms.append(dataset.RandomRotate())
    # train_transforms.append(dataset.RandomCrop_Unaligned(cfg.FINE_SIZE, pad_if_needed=True, fill=0))  #
    train_transforms.append(
        dataset.RandomCrop(cfg.FINE_SIZE, pad_if_needed=True, fill=0))  #
    train_transforms.append(dataset.RandomHorizontalFlip())
    if cfg.TARGET_MODAL == 'lab':
        train_transforms.append(dataset.RGB2Lab())
    if cfg.MULTI_SCALE:
        ms_targets.extend(cfg.MULTI_TARGETS)
        train_transforms.append(
            dataset.MultiScale(size=cfg.FINE_SIZE,
                               scale_times=cfg.MULTI_SCALE_NUM,
                               ms_targets=ms_targets))
    train_transforms.append(dataset.ToTensor(ms_targets=ms_targets))
    train_transforms.append(
        dataset.Normalize(mean=cfg.MEAN, std=cfg.STD, ms_targets=ms_targets))

    val_transforms.append(dataset.Resize(cfg.LOAD_SIZE))
    if not cfg.SLIDE_WINDOWS:
        val_transforms.append(dataset.CenterCrop(cfg.FINE_SIZE))

    if cfg.MULTI_SCALE:
        val_transforms.append(
            dataset.MultiScale(size=cfg.FINE_SIZE,
                               scale_times=cfg.MULTI_SCALE_NUM,
                               ms_targets=ms_targets))
    val_transforms.append(dataset.ToTensor(ms_targets=ms_targets))
    val_transforms.append(
        dataset.Normalize(mean=cfg.MEAN, std=cfg.STD, ms_targets=ms_targets))

    train_dataset = dataset.__dict__[cfg.DATASET](
        cfg=cfg,
        transform=transforms.Compose(train_transforms),
        data_dir=cfg.DATA_DIR_TRAIN,
        phase_train=True)
    val_dataset = dataset.__dict__[cfg.DATASET](
        cfg=cfg,
        transform=transforms.Compose(val_transforms),
        data_dir=cfg.DATA_DIR_VAL,
        phase_train=False)
    cfg.CLASS_WEIGHTS_TRAIN = train_dataset.class_weights
    cfg.IGNORE_LABEL = train_dataset.ignore_label

    cfg.train_dataset = train_dataset
    cfg.val_dataset = val_dataset

    port = random.randint(8001, 9000)
    ngpus_per_node = len(cfg.GPU_IDS)
    if cfg.MULTIPROCESSING_DISTRIBUTED:
        cfg.rank = 0
        cfg.ngpus_per_node = ngpus_per_node
        cfg.dist_url = 'tcp://127.0.0.1:' + str(port)
        cfg.dist_backend = 'nccl'
        cfg.opt_level = 'O0'
        cfg.world_size = 1

    cfg.print_args()

    if cfg.MULTIPROCESSING_DISTRIBUTED:
        cfg.world_size = cfg.ngpus_per_node * cfg.world_size
        mp.spawn(main_worker,
                 nprocs=cfg.ngpus_per_node,
                 args=(cfg.ngpus_per_node, cfg))
    else:
        # Simply call main_worker function
        main_worker(cfg.GPU_IDS, ngpus_per_node, cfg)
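main_worker itself is not shown in this example. A minimal sketch of how a process spawned by mp.spawn typically consumes the dist_url, dist_backend, world_size, and rank fields prepared above (the body is an assumption for the spawned path, not this project's actual worker):

import torch
import torch.distributed as dist

def main_worker(gpu, ngpus_per_node, cfg):
    # mp.spawn passes the local process index as the first argument;
    # combine it with cfg.rank to get a unique global rank.
    rank = cfg.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=cfg.dist_backend,
                            init_method=cfg.dist_url,
                            world_size=cfg.world_size,
                            rank=rank)
    torch.cuda.set_device(gpu)
    # ... build the model, wrap it in DistributedDataParallel, train ...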
Example #4
File: train.py  Project: chenbys/MTRecgNet
def train():
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # Setting random seed
    if cfg.MANUAL_SEED is None:
        cfg.MANUAL_SEED = random.randint(1, 10000)
    random.seed(cfg.MANUAL_SEED)
    torch.manual_seed(cfg.MANUAL_SEED)

    # args for different backbones
    cfg.parse(args['resnet18'])

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS
    device_ids = torch.cuda.device_count()
    print('device_ids:', device_ids)
    project_name = reduce(lambda x, y: str(x) + '/' + str(y),
                          os.path.realpath(__file__).split(os.sep)[:-1])
    util.mkdir('logs')

    # data
    train_dataset = dataset.AlignedConcDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]))

    val_dataset = dataset.AlignedConcDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]))
    batch_size = cfg.BATCH_SIZE

    unlabeled_loader = None
    if cfg.UNLABELED:
        unlabeled_dataset = dataset.AlignedConcDataset(
            cfg,
            data_dir=cfg.DATA_DIR_UNLABELED,
            transform=transforms.Compose([
                dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
                dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
                dataset.RandomHorizontalFlip(),
                dataset.ToTensor(),
                dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ]),
            labeled=False)

        unlabeled_loader = DataProvider(cfg, dataset=unlabeled_dataset)

    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=batch_size)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=batch_size,
                              shuffle=False)

    # class weights
    num_classes_train = list(
        Counter([i[1] for i in train_loader.dataset.imgs]).values())
    cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard
    model = TRecgNet(cfg, writer=writer)
    model.set_data_loader(train_loader, val_loader, unlabeled_loader)
    if cfg.RESUME:
        checkpoint_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.RESUME_PATH)
        checkpoint = torch.load(checkpoint_path)
        load_epoch = checkpoint['epoch']
        model.load_checkpoint(model.net,
                              checkpoint_path,
                              checkpoint,
                              data_para=True)
        cfg.START_EPOCH = load_epoch

        if cfg.INIT_EPOCH:
            # just load pretrained parameters
            print('load checkpoint from another source')
            cfg.START_EPOCH = 1

    print('>>> task path is {0}'.format(project_name))

    # train
    model.train_parameters(cfg)

    print('save model ...')
    model_filename = '{0}_{1}_{2}.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION,
                                              cfg.NITER_TOTAL)
    model.save_checkpoint(cfg.NITER_TOTAL, model_filename)

    if writer is not None:
        writer.close()
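DataProvider is used in Examples #2 and #4 but never defined in these excerpts. A plausible minimal version, assuming it is a thin wrapper around torch.utils.data.DataLoader that fills in defaults from cfg (the signature and defaults are guesses, not the repo's actual class):

from torch.utils.data import DataLoader

def DataProvider(cfg, dataset, batch_size=None, shuffle=True):
    # Convenience wrapper: call sites only override what differs from cfg.
    return DataLoader(dataset,
                      batch_size=batch_size or cfg.BATCH_SIZE,
                      shuffle=shuffle,
                      num_workers=cfg.WORKERS,
                      pin_memory=True,
                      drop_last=shuffle)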