コード例 #1
0
def main(cfg, gpu):
    """Build the encoder/decoder segmentation model and run inference on the test list.

    Args:
        cfg: Config node. Reads ``cfg.MODEL.arch_encoder`` / ``arch_decoder`` /
            ``fc_dim`` / ``weights_encoder`` / ``weights_decoder``,
            ``cfg.DATASET`` (incl. ``num_class``), ``cfg.list_test`` and
            ``cfg.TEST.batch_size``.
        gpu: CUDA device index; made current and forwarded to ``test``.
    """
    torch.cuda.set_device(gpu)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(arch=cfg.MODEL.arch_encoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(arch=cfg.MODEL.arch_decoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             num_class=cfg.DATASET.num_class,
                                             weights=cfg.MODEL.weights_decoder,
                                             use_softmax=True)

    # Targets equal to -1 are ignored by the loss.
    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # Dataset and Loader.
    # drop_last must be False for inference: with drop_last=True the final
    # partial batch (len(dataset) % batch_size images) would be silently
    # skipped and those images never segmented.
    dataset_test = TestDataset(cfg.list_test, cfg.DATASET)
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=cfg.TEST.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=False)

    segmentation_module.cuda()

    # Main loop
    test(segmentation_module, loader_test, gpu)

    print('Inference done!')
コード例 #2
0
    def model_builder(self, imode="sphe"):
        """Build an eval-mode ResNet50-dilated + PPM-deepsup segmentation model on CUDA.

        Args:
            imode: ``"sphe"`` builds the networks with ``seg_sphe.ModelBuilder``
                from the ADE20K epoch-20 checkpoints; ``"persp"`` builds them
                with ``seg_persp.ModelBuilder`` from the ``ckpt_nef`` epoch-40
                checkpoints.

        Returns:
            A ``SegmentationModule`` (150 classes, softmax output) in eval
            mode, moved to CUDA.

        Raises:
            ValueError: if ``imode`` is neither ``"sphe"`` nor ``"persp"``.
        """
        if imode == "sphe":
            builder = seg_sphe.ModelBuilder
            encoder_weights = 'ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth'
            decoder_weights = 'ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth'
        elif imode == "persp":
            builder = seg_persp.ModelBuilder
            # Alternative weights: the plain ADE20K epoch-20 checkpoints under
            # ckpt/ade20k-resnet50dilated-ppm_deepsup/ also fit this architecture.
            encoder_weights = 'ckpt_nef/r50d_ppm_rot_e40_nef_30/encoder_epoch_40.pth'
            decoder_weights = 'ckpt_nef/r50d_ppm_rot_e40_nef_30/decoder_epoch_40.pth'
        else:
            # Without this an unknown mode leaves net_encoder unbound and the
            # failure surfaces later as a confusing UnboundLocalError.
            raise ValueError(
                f"unknown imode {imode!r}; expected 'sphe' or 'persp'")

        # Network Builders (same architecture for both modes)
        net_encoder = builder.build_encoder(
            arch='resnet50dilated',
            fc_dim=2048,
            weights=encoder_weights)
        net_decoder = builder.build_decoder(
            arch='ppm_deepsup',
            fc_dim=2048,
            num_class=150,
            weights=decoder_weights,
            use_softmax=True)

        crit = torch.nn.NLLLoss(ignore_index=-1)
        semseg_model = SegmentationModule(net_encoder, net_decoder, crit)
        semseg_model.eval()
        semseg_model.cuda()

        return semseg_model
コード例 #3
0
def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
    """Evaluate one slice of the validation list on a single GPU.

    Makes ``gpu_id`` the current CUDA device, loads the part of
    ``cfg.DATASET.list_val`` selected by ``start_idx``/``end_idx``, builds the
    encoder/decoder described by ``cfg.MODEL`` and passes the model, loader
    and ``result_queue`` on to ``evaluate``.
    """
    torch.cuda.set_device(gpu_id)

    # Data: only this worker's slice of the validation list.
    shard = ValDataset(cfg.DATASET.root_dataset,
                       cfg.DATASET.list_val,
                       cfg.DATASET,
                       start_idx=start_idx,
                       end_idx=end_idx)
    shard_loader = torch.utils.data.DataLoader(
        shard,
        batch_size=cfg.VAL.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=2,
    )

    # Model: architecture names are lower-cased before lookup.
    model_cfg = cfg.MODEL
    encoder = ModelBuilder.build_encoder(
        arch=model_cfg.arch_encoder.lower(),
        fc_dim=model_cfg.fc_dim,
        weights=model_cfg.weights_encoder,
    )
    decoder = ModelBuilder.build_decoder(
        arch=model_cfg.arch_decoder.lower(),
        fc_dim=model_cfg.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=model_cfg.weights_decoder,
        use_softmax=True,
    )
    model = SegmentationModule(encoder, decoder, nn.NLLLoss(ignore_index=-1))
    model.cuda()

    evaluate(model, shard_loader, cfg, gpu_id, result_queue)
コード例 #4
0
ファイル: semantic_segmentation.py プロジェクト: jakobfp/ml4a
def setup(gpu):
    """Create the module-level ``segmentation_module`` for inference.

    Fetches the ADE20K ResNet50-dilated encoder and PPM-deepsup decoder
    (epoch 20) via ``downloads.download_data_file``. It then assigns the
    resulting ``SegmentationModule`` to the global ``segmentation_module``,
    switches it to eval mode and moves it to CUDA.

    NOTE(review): ``gpu`` is accepted but never used here; ``.cuda()`` targets
    the current default device. Confirm whether callers expect it honored.
    """
    global segmentation_module

    encoder_weights = downloads.download_data_file(
        'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
        'semantic-segmentation-pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth'
    )
    encoder = ModelBuilder.build_encoder(arch='resnet50dilated',
                                         fc_dim=2048,
                                         weights=encoder_weights)

    decoder_weights = downloads.download_data_file(
        'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
        'semantic-segmentation-pytorch/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth'
    )
    decoder = ModelBuilder.build_decoder(arch='ppm_deepsup',
                                         fc_dim=2048,
                                         num_class=150,
                                         weights=decoder_weights,
                                         use_softmax=True)

    loss_fn = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(encoder, decoder, loss_fn)
    segmentation_module.eval()
    segmentation_module.cuda()
コード例 #5
0
    def _load_model(self, cfg_file, gpu):
        """Load config and checkpoints, then store an eval-mode model in ``self.model``.

        Merges ``cfg_file`` (relative to the ``object_segmentation`` ROS
        package) into the module-level ``cfg`` and lower-cases the
        architecture names. It resolves the encoder/decoder checkpoint paths
        under ``<package>/<cfg.DIR>``. If either checkpoint is missing, it
        runs ``download.ycb`` into ``<package>/ckpt``. Finally it builds the
        ``SegmentationModule``.

        Args:
            cfg_file: Config file path relative to the package root.
            gpu: When not None, CUDA device 0 is made current.

        Raises:
            AssertionError: if a checkpoint is still missing after the download
                attempt (raised explicitly so the check survives ``python -O``).
        """
        if gpu is not None:
            # NOTE(review): always selects device 0 regardless of the value of
            # ``gpu``; confirm whether the index should be passed through.
            torch.cuda.set_device(0)
        basepath = rospkg.RosPack().get_path('object_segmentation')
        cfg.merge_from_file(basepath + "/" + cfg_file)

        logger = setup_logger(distributed_rank=0)
        logger.info(f"Loaded configuration file {cfg_file}")
        logger.info("Running with config:\n{}".format(cfg))

        cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
        cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()

        # absolute paths of model weights
        ckpt_dir = Path(basepath) / cfg.DIR
        cfg.MODEL.weights_encoder = (
            ckpt_dir / ('encoder_' + cfg.TEST.checkpoint)).as_posix()
        cfg.MODEL.weights_decoder = (
            ckpt_dir / ('decoder_' + cfg.TEST.checkpoint)).as_posix()
        weight_files = (cfg.MODEL.weights_encoder, cfg.MODEL.weights_decoder)

        if not all(os.path.exists(path) for path in weight_files):
            download.ycb(Path(basepath) / 'ckpt')

        for weight_file in weight_files:
            # Explicit raise instead of ``assert`` so validation is not
            # stripped under -O; AssertionError is kept for existing callers.
            if not os.path.exists(weight_file):
                raise AssertionError(
                    f"checkpoint {weight_file} does not exist!")

        # Network Builders
        net_encoder = ModelBuilder.build_encoder(
            arch=cfg.MODEL.arch_encoder,
            fc_dim=cfg.MODEL.fc_dim,
            weights=cfg.MODEL.weights_encoder)
        net_decoder = ModelBuilder.build_decoder(
            arch=cfg.MODEL.arch_decoder,
            fc_dim=cfg.MODEL.fc_dim,
            num_class=cfg.DATASET.num_class,
            weights=cfg.MODEL.weights_decoder,
            use_softmax=True)

        crit = nn.NLLLoss(ignore_index=-1)

        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit)
        # NOTE(review): tests ``self.gpu`` here but the ``gpu`` argument above;
        # presumably the same value — confirm with the caller.
        if self.gpu is not None:
            segmentation_module.cuda()
        segmentation_module.eval()
        self.model = segmentation_module
コード例 #6
0
def semantic_segmentation(image_folder, image_name='000000.jpg'):
    """Segment one image using an ADE20K ResNet101-dilated + PPM-deepsup model.

    Calls ``download_segmentation()`` and builds the model from
    ``data/segm_data/{encoder,decoder}_epoch_25.pth``. It runs the model once
    at the image's full resolution and frees the model before returning.

    Args:
        image_folder: Directory containing the image.
        image_name: File name inside ``image_folder``. Defaults to
            ``'000000.jpg'``, the previously hard-coded name.

    Returns:
        numpy array of per-pixel arg-max class indices for the image
        (presumably shape (H, W) — it is ``pred[0]`` of a batch of one).
    """
    download_segmentation()
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch='resnet101dilated',
        fc_dim=2048,
        weights='data/segm_data/encoder_epoch_25.pth')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='data/segm_data/decoder_epoch_25.pth',
        use_softmax=True)

    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
    segmentation_module.eval()
    segmentation_module.cuda()

    # Load and normalize one image as a singleton tensor batch
    pil_to_tensor = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # These are RGB mean+std values
            std=[0.229, 0.224, 0.225])  # across a large photo dataset.
    ])
    image_path = os.path.join(image_folder, image_name)
    # convert() returns a new, fully loaded image, so the source file can be
    # closed deterministically by the context manager.
    with PIL.Image.open(image_path) as raw_image:
        pil_image = raw_image.convert('RGB')
    img_data = pil_to_tensor(pil_image)
    singleton_batch = {'img_data': img_data[None].cuda()}
    # ToTensor yields (C, H, W); request scores at the input's (H, W).
    output_size = img_data.shape[1:]

    # Run the segmentation at the highest resolution.
    with torch.no_grad():
        scores = segmentation_module(singleton_batch, segSize=output_size)

    # Get the predicted scores for each pixel
    _, pred = torch.max(scores, dim=1)
    pred = pred.cpu()[0].numpy()
    # Drop references to the model and GPU tensors so their memory can be
    # reclaimed once this function returns.
    del net_encoder, net_decoder, segmentation_module, scores, singleton_batch
    return pred
コード例 #7
0
def load_model_from_url(semsegPath):
    """Download the ADE20K ResNet50-dilated + PPM-deepsup checkpoints and build the model.

    Both checkpoints are (re)downloaded on every call into
    ``<semsegPath>/ckpt/encoder.pth`` and ``<semsegPath>/ckpt/decoder.pth``.

    Args:
        semsegPath: Directory under which ``ckpt/`` is created if missing.

    Returns:
        The ``SegmentationModule`` (150 classes, softmax output) in eval mode,
        moved to CUDA.

    Raises:
        requests.HTTPError: if a download returns an HTTP error status.
        requests.Timeout: if the server stops responding for 60 seconds.
    """
    model_urls = {
        'encoder':
        'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
        'decoder':
        'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth'
    }

    def _download(url, dest):
        # Stream to disk instead of buffering the whole checkpoint in memory,
        # and fail loudly on HTTP errors: otherwise an error page would be
        # saved as the .pth file and only fail later when loading weights.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(dest, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)

    os.makedirs(os.path.join(semsegPath, 'ckpt'), exist_ok=True)

    encoder_path = os.path.join(semsegPath, 'ckpt/encoder.pth')
    decoder_path = os.path.join(semsegPath, 'ckpt/decoder.pth')
    _download(model_urls['encoder'], encoder_path)
    _download(model_urls['decoder'], decoder_path)

    net_encoder = ModelBuilder.build_encoder(arch='resnet50dilated',
                                             fc_dim=2048,
                                             weights=encoder_path)
    net_decoder = ModelBuilder.build_decoder(arch='ppm_deepsup',
                                             fc_dim=2048,
                                             num_class=150,
                                             weights=decoder_path,
                                             use_softmax=True)

    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
    segmentation_module.eval()
    segmentation_module.cuda()

    return segmentation_module
コード例 #8
0
ファイル: detect.py プロジェクト: callaunchpad/Watchman
def create_network(cfg, gpu):
    """Build the encoder/decoder segmentation model on CUDA device ``gpu``.

    Args:
        cfg: Config node. Reads ``cfg.MODEL.arch_encoder`` / ``arch_decoder`` /
            ``fc_dim`` / ``weights_encoder`` / ``weights_decoder`` and
            ``cfg.DATASET.num_class``.
        gpu: CUDA device index made current before building the model.

    Returns:
        ``(segmentation_module, gpu)``: the ``SegmentationModule`` moved to
        CUDA (softmax decoder output, ``NLLLoss`` criterion) and the device
        index that was passed in.
    """
    torch.cuda.set_device(gpu)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(arch=cfg.MODEL.arch_encoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(arch=cfg.MODEL.arch_decoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             num_class=cfg.DATASET.num_class,
                                             weights=cfg.MODEL.weights_decoder,
                                             use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    segmentation_module.cuda()

    return segmentation_module, gpu
コード例 #9
0
ファイル: pipeline2.py プロジェクト: toshi2k2/asco_perception
def segmentation_model_init():
    """Fetch the ADE20K ResNet101 + UPerNet checkpoints if absent and build the model.

    Checkpoints are cached in the module-level ``model_folder`` directory and
    downloaded only when missing.

    Side effect: autograd is disabled process-wide via
    ``torch.set_grad_enabled(False)``.

    Returns:
        ``(segmentation_module, options)``. The module is in eval mode and
        moved to CUDA when available. ``options`` is a ``SimpleNamespace``
        holding ``fc_dim``, ``num_class`` and the image-size / padding /
        downsampling settings.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(model_folder, exist_ok=True)

    ENCODER_NAME = 'resnet101'
    DECODER_NAME = 'upernet'
    PRETRAINED_ENCODER_MODEL_URL = 'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-%s-%s/encoder_epoch_50.pth' % (
        ENCODER_NAME, DECODER_NAME)
    PRETRAINED_DECODER_MODEL_URL = 'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-%s-%s/decoder_epoch_50.pth' % (
        ENCODER_NAME, DECODER_NAME)

    pretrained_encoder_file = ENCODER_NAME + basename(
        PRETRAINED_ENCODER_MODEL_URL)
    pretrained_decoder_file = DECODER_NAME + basename(
        PRETRAINED_DECODER_MODEL_URL)
    encoder_path = os.path.join(model_folder, pretrained_encoder_file)
    decoder_path = os.path.join(model_folder, pretrained_decoder_file)

    def _fetch_if_missing(url, path):
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that the
        # exists() check would later accept as a valid checkpoint.
        if os.path.exists(path):
            return
        partial_path = path + '.part'
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, path)

    _fetch_if_missing(PRETRAINED_ENCODER_MODEL_URL, encoder_path)
    _fetch_if_missing(PRETRAINED_DECODER_MODEL_URL, decoder_path)

    # options
    options = SimpleNamespace(fc_dim=2048,
                              num_class=150,
                              imgSizes=[300, 400, 500, 600],
                              imgMaxSize=1000,
                              padding_constant=8,
                              segm_downsampling_rate=8)

    # create model
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=ENCODER_NAME,
                                        weights=encoder_path,
                                        fc_dim=options.fc_dim)
    net_decoder = builder.build_decoder(arch=DECODER_NAME,
                                        weights=decoder_path,
                                        fc_dim=options.fc_dim,
                                        num_class=options.num_class,
                                        use_softmax=True)
    segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                             torch.nn.NLLLoss(ignore_index=-1))
    segmentation_module = segmentation_module.eval()
    torch.set_grad_enabled(False)

    if torch.cuda.is_available():
        segmentation_module = segmentation_module.cuda()
    return segmentation_module, options
コード例 #10
0
def huawei_seg(imgs):
    """Segment a single image with the model described by the ``--cfg`` config.

    ``--cfg`` and ``--gpu`` are read from ``sys.argv``; unrecognized arguments
    are ignored so this can be called from programs with their own CLI.

    The config is merged into the module-level ``cfg``. The model is built on
    ``--gpu`` and run on every resized version of the image produced by
    ``InferDataset``, with the per-scale scores summed, each divided by
    ``len(cfg.DATASET.imgSizes)``. The result is passed to
    ``visualize_result``.

    Args:
        imgs: The image to segment, stored as ``fpath_img`` (presumably a
            file path — confirm against ``InferDataset``).

    Returns:
        The per-pixel arg-max class indices as a numpy array, or ``None`` if
        the loader yields no batches.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing")
    parser.add_argument(
        "--cfg",
        default="config/ade20k-hrnetv2-huawei.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--gpu",
                        default=0,
                        type=int,
                        help="gpu id for evaluation")

    # parse_known_args(): argv entries that belong to the calling program must
    # not abort the process with "unrecognized arguments".
    args, _unknown = parser.parse_known_args()

    cfg.merge_from_file(args.cfg)

    cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
    cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()

    # model weight paths under cfg.DIR
    cfg.MODEL.weights_encoder = os.path.join(cfg.DIR,
                                             'encoder_' + cfg.TEST.checkpoint)
    cfg.MODEL.weights_decoder = os.path.join(cfg.DIR,
                                             'decoder_' + cfg.TEST.checkpoint)
    # The single image is wrapped in the list-of-dicts form InferDataset reads.
    cfg.list_test = [{'fpath_img': imgs}]

    torch.cuda.set_device(args.gpu)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(arch=cfg.MODEL.arch_encoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(arch=cfg.MODEL.arch_decoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             num_class=cfg.DATASET.num_class,
                                             weights=cfg.MODEL.weights_decoder,
                                             use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # Dataset and Loader
    dataset_test = InferDataset(cfg.list_test, cfg.DATASET)
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True)

    segmentation_module.cuda()
    segmentation_module.eval()

    pred = None  # returned as-is if the loader yields no batches
    # The context manager closes the progress bar even if inference raises.
    with tqdm(total=len(loader_test)) as pbar:
        for batch_data in loader_test:
            # The collated batch is indexed with [0]; presumably a list
            # holding one per-sample dict since batch_size=1.
            batch_data = batch_data[0]
            segSize = (batch_data['img_ori'].shape[0],
                       batch_data['img_ori'].shape[1])
            img_resized_list = batch_data['img_data']

            with torch.no_grad():
                scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                     segSize[1])
                scores = async_copy_to(scores, args.gpu)

                for img in img_resized_list:
                    feed_dict = batch_data.copy()
                    feed_dict['img_data'] = img
                    del feed_dict['img_ori']
                    del feed_dict['info']
                    feed_dict = async_copy_to(feed_dict, args.gpu)

                    # forward pass
                    pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                    # NOTE(review): divides by len(cfg.DATASET.imgSizes),
                    # assuming one resized image per configured size.
                    scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)

                _, pred = torch.max(scores, dim=1)
                pred = as_numpy(pred.squeeze(0).cpu())
            # visualization
            visualize_result((batch_data['img_ori'], batch_data['info']),
                             pred, cfg)
            pbar.update(1)
    return pred
コード例 #11
0
def main(cfg, gpus):
    """Train the segmentation model described by ``cfg`` on the CUDA devices ``gpus``.

    Builds the encoder/decoder from ``cfg.MODEL`` (architecture names
    lower-cased) and wraps them in a ``SegmentationModule`` with an
    ``NLLLoss`` criterion. It then runs ``train`` followed by ``checkpoint``
    once per epoch in ``[cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch)``,
    passing 1-based epoch numbers.
    """
    model_cfg = cfg.MODEL
    encoder = ModelBuilder.build_encoder(
        arch=model_cfg.arch_encoder.lower(),
        fc_dim=model_cfg.fc_dim,
        weights=model_cfg.weights_encoder,
    )
    decoder = ModelBuilder.build_decoder(
        arch=model_cfg.arch_decoder.lower(),
        fc_dim=model_cfg.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=model_cfg.weights_decoder,
    )

    crit = nn.NLLLoss(ignore_index=-1)

    module_args = [encoder, decoder, crit]
    if model_cfg.arch_decoder.endswith('deepsup'):
        # Deep-supervision decoders take an extra scale factor
        # (presumably the auxiliary-loss weight).
        module_args.append(cfg.TRAIN.deep_sup_scale)
    segmentation_module = SegmentationModule(*module_args)

    train_set = TrainDataset(cfg.DATASET.root_dataset,
                             cfg.DATASET.list_train,
                             cfg.DATASET,
                             batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True,
    )
    print(f'1 Epoch = {cfg.TRAIN.epoch_iters} iters')

    # A single iterator is shared across all epochs.
    batches = iter(train_loader)

    print(gpus)
    # Wrapped unconditionally, even for a single GPU.
    segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                    device_ids=gpus)
    patch_replication_callback(segmentation_module)  # For sync bn
    segmentation_module.cuda()

    nets = (encoder, decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
    for epoch_num in range(cfg.TRAIN.start_epoch + 1, cfg.TRAIN.num_epoch + 1):
        train(segmentation_module, batches, optimizers, history, epoch_num,
              cfg)
        checkpoint(nets, history, cfg, epoch_num)

    print('Training Done!')
コード例 #12
0
# hrnet
# Build an HRNetV2 encoder + C1 decoder (150 classes) from local epoch-30
# checkpoints, then run it once on 'sample.png' at full resolution.
net_encoder = ModelBuilder.build_encoder(
    arch='hrnetv2',
    fc_dim=720,
    weights='../../ckpt/ade20k-hrnetv2-c1/encoder_epoch_30.pth')
net_decoder = ModelBuilder.build_decoder(
    arch='c1',
    fc_dim=720,
    num_class=150,
    weights='../../ckpt/ade20k-hrnetv2-c1/decoder_epoch_30.pth',
    use_softmax=True)

# Targets equal to -1 are ignored by the loss.
crit = torch.nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
# Inference setup: eval mode, then move parameters to the default CUDA device.
segmentation_module.eval()
segmentation_module.cuda()

# PIL image -> float tensor, normalized per RGB channel.
pil_to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # These are RGB mean+std values
        std=[0.229, 0.224, 0.225])  # across a large photo dataset.
])
pil_image = PIL.Image.open('sample.png').convert('RGB')
img_original = numpy.array(pil_image)  # NOTE(review): unused in the visible code
img_data = pil_to_tensor(pil_image)
# Add a leading batch dimension of 1 and move the batch to the GPU.
singleton_batch = {'img_data': img_data[None].cuda()}
# ToTensor yields (C, H, W), so this is (H, W): scores come back at input size.
output_size = img_data.shape[1:]

# Inference only: no autograd bookkeeping needed.
with torch.no_grad():
    scores = segmentation_module(singleton_batch, segSize=output_size)
コード例 #13
0
def main(cfg, gpus):
    """Train the segmentation or regression model described by ``cfg``, validating each epoch.

    Decoders whose name ends in ``'regression'`` are trained with a summed
    ``MSELoss`` on ``TrainDatasetRegression``. All other decoders use
    ``NLLLoss`` on ``TrainDataset``. Both paths build the validation dataset
    from ``cfg.DATASET.list_val``, because ``val`` runs after every training
    epoch.

    Args:
        cfg: Config node. Reads the ``MODEL``, ``DATASET`` and ``TRAIN``
            sections.
        gpus: List of CUDA device ids. ``len(gpus)`` is also used as the
            loader batch size.
    """
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    is_regression = cfg.MODEL.arch_decoder.endswith('regression')
    if is_regression:
        crit = nn.MSELoss(
            reduction="sum"
        )  # Sum for multi-output learning, need to sum across all labels
    else:
        crit = nn.NLLLoss(ignore_index=-1)  # negative log likelihood loss

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg.DATASET.classes,
                                                 cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder,
                                                 crit, cfg.DATASET.classes)

    print("net_encoder")
    print(type(net_encoder))
    print(net_encoder)
    print("net_decoder")
    print(type(net_decoder))
    print(net_decoder)

    # Dataset and Loader.
    # Both branches must define dataset_val: validation runs every epoch, and
    # a missing val dataset used to crash non-regression runs with NameError.
    if is_regression:
        print("performing regression")
        dataset_train = TrainDatasetRegression(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_train,
            cfg.DATASET.classes,
            cfg.DATASET,
            batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

        dataset_val = TrainDatasetRegression(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_val,
            cfg.DATASET.classes,
            cfg.DATASET,
            batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
    else:
        dataset_train = TrainDataset(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_train,
            cfg.DATASET,
            batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

        dataset_val = TrainDataset(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_val,
            cfg.DATASET,
            batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    def _make_loader(dataset):
        # Train and val loaders share identical settings.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=len(gpus),  # we have modified data_parallel
            shuffle=False,  # we do not use this param
            collate_fn=user_scattered_collate,
            num_workers=cfg.TRAIN.workers,
            drop_last=True,
            pin_memory=True)

    loader_train = _make_loader(dataset_train)
    loader_val = _make_loader(dataset_val)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterators (shared across all epochs)
    iterator_train = iter(loader_train)
    iterator_val = iter(loader_val)

    # load nets into gpu
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(segmentation_module,
                                                        device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {
        'train': {
            'epoch': [],
            'loss': []
        },
        'val': {
            'epoch': [],
            'loss': []
        }
    }

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)
        val(segmentation_module, iterator_val, optimizers, history, epoch + 1,
            cfg)

        # Checkpoint every 5th epoch (1-based epochs 1, 6, 11, ...) and always
        # after the final epoch so the last trained weights are not lost.
        is_last_epoch = epoch + 1 == cfg.TRAIN.num_epoch
        if epoch % 5 == 0 or is_last_epoch:
            checkpoint(nets, history, cfg, epoch + 1)

    print('Training Done!')
    print('Training Done!')