Example #1
def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
    torch.cuda.set_device(gpu_id)

    # Dataset and Loader
    dataset_val = ValDataset(cfg.DATASET.root_dataset,
                             cfg.DATASET.list_val,
                             cfg.DATASET,
                             start_idx=start_idx,
                             end_idx=end_idx)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=cfg.VAL.batch_size,
                                             shuffle=False,
                                             collate_fn=user_scattered_collate,
                                             num_workers=2)

    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query)
    net_enc_memory = ModelBuilder.build_encoder_memory(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.DATASET.num_class)
    net_att_query = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationAttentionModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        crit,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar)

    segmentation_module.cuda()

    # Main loop
    evaluate(segmentation_module, loader_val, cfg, gpu_id, result_queue)
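A minimal launcher sketch for the worker above (not part of the original source): it assumes one process per GPU, index-range sharding of the validation set, and a shared multiprocessing queue for collecting results. The function name launch_eval, the sharding arithmetic, and the assumption that each worker puts exactly one summary object on the queue are illustrative only; the real protocol depends on evaluate.

# Hypothetical launcher: spawn one evaluation worker per GPU and collect
# per-GPU results through a shared queue.
import torch.multiprocessing as mp

def launch_eval(cfg, num_gpus, num_samples):
    ctx = mp.get_context('spawn')  # CUDA requires the spawn start method
    result_queue = ctx.Queue()
    shard = (num_samples + num_gpus - 1) // num_gpus
    procs = []
    for gpu_id in range(num_gpus):
        start_idx = gpu_id * shard
        end_idx = min(start_idx + shard, num_samples)
        p = ctx.Process(target=worker,
                        args=(cfg, gpu_id, start_idx, end_idx, result_queue))
        p.start()
        procs.append(p)
    # Illustrative assumption: each worker puts one summary object on the queue.
    results = [result_queue.get() for _ in range(num_gpus)]
    for p in procs:
        p.join()
    return results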
Example #2
def main(cfg, gpus):
    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query)
    net_enc_memory = ModelBuilder.build_encoder_memory(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.DATASET.num_class)
    net_att_query = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationAttentionModule(
            net_enc_query,
            net_enc_memory,
            net_att_query,
            net_att_memory,
            net_decoder,
            crit,
            cfg.TRAIN.deep_sup_scale,
            normalize_key=cfg.MODEL.normalize_key,
            p_scalar=cfg.MODEL.p_scalar)
    else:
        segmentation_module = SegmentationAttentionModule(
            net_enc_query,
            net_enc_memory,
            net_att_query,
            net_att_memory,
            net_decoder,
            crit,
            normalize_key=cfg.MODEL.normalize_key,
            p_scalar=cfg.MODEL.p_scalar)

    # Dataset and Loader
    dataset_train = TrainDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_train,
                                 cfg.DATASET,
                                 cfg.DATASET.ref_path,
                                 cfg.DATASET.ref_start,
                                 cfg.DATASET.ref_end,
                                 batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                    device_ids=gpus)
    # For sync bn
    patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)

        # checkpointing
        if (epoch + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, epoch + 1)

    print('Training Done!')
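For context, a minimal driver sketch that could invoke this main (not part of the original example). It assumes the configuration object is a yacs CfgNode loaded from a YAML file, so attribute access such as cfg.MODEL.fc_dim behaves as the code above expects; the flag names --cfg and --gpus are illustrative assumptions.

# Hypothetical entry point; the yacs-based config and the flag names are assumptions.
import argparse
from yacs.config import CfgNode

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training driver (illustrative)')
    parser.add_argument('--cfg', required=True, help='path to a YAML config file')
    parser.add_argument('--gpus', default='0', help='comma-separated GPU ids, e.g. "0,1"')
    args = parser.parse_args()

    # Load the YAML file into a CfgNode so nested attribute access works.
    with open(args.cfg) as f:
        cfg = CfgNode.load_cfg(f)

    gpus = [int(g) for g in args.gpus.split(',')]
    main(cfg, gpus)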