コード例 #1
0
def main():
    """Colorize one 'val' sample from its labels, then run the network on it.

    Relies on names defined elsewhere in the module: ``root``, ``GMM_PATH``,
    ``MEAN_L_PATH``, ``model`` and ``cuda``.
    """
    dataset = data_loader.ColorizeImageNet(root,
                                           split='val',
                                           set='small',
                                           bins='soft',
                                           num_hc_bins=16,
                                           gmm_path=GMM_PATH,
                                           mean_l_path=MEAN_L_PATH)
    # Use one index for both the dataset sample and the raw file so that the
    # colorization and the original image refer to the same picture.
    idx = 0
    input_im, labels = dataset[idx]
    gmm = dataset.gmm
    mean_l = dataset.mean_l

    img_file = dataset.files['val'][idx]
    im_orig = skimage.io.imread(img_file)

    # "Ground-truth" colorization from the dataset labels and the
    # (mean-subtracted) input image.
    labels = labels.numpy()
    img = input_im.squeeze().numpy()
    im_rgb = utils.colorize_image_hc(labels, img, gmm, mean_l)

    plt.imshow(im_rgb)
    plt.show()

    # The network needs a tensor with a leading batch axis, not the squeezed
    # numpy array used for display above.
    inputs = Variable(input_im.unsqueeze(0))
    if cuda:
        inputs = inputs.cuda()
    outputs = model(inputs)
コード例 #2
0
def test_single_read():
    """Read train sample 0 and check the value range of its first label map."""
    print('Entering: test_single_read')
    ds = data_loader.ColorizeImageNet(root, split='train', set='small')
    _, label_pair = ds.__getitem__(0)
    # The label is a pair; its first map must span exactly 0..30.
    assert len(label_pair) == 2
    first_map = label_pair[0].numpy()
    assert np.min(first_map) == 0
    assert np.max(first_map) == 30
    print('Test passed: test_single_read')
コード例 #3
0
def test_lowpass_image():
    """Read one train sample with soft bins and ``img_lowpass=8``.

    Checks that both the input image and the label come back as float
    tensors.
    """
    print('Entering: test_lowpass_image')
    dataset = data_loader.ColorizeImageNet(root, split='train', set='small',
                                           bins='soft', img_lowpass=8)
    img, lbl = dataset[0]
    # isinstance (rather than type() ==) also holds on PyTorch versions where
    # every tensor's concrete type is torch.Tensor.
    assert isinstance(lbl, torch.FloatTensor)
    assert isinstance(img, torch.FloatTensor)
    print('Test passed: test_lowpass_image')
コード例 #4
0
def test_rgb_hsv():
    """Deferred RGB/HSV test (marked DEFER).

    For now it only checks that train image #100, read as a uint8 array,
    has 400 as its largest shape entry.
    """
    # DEFER
    ds = data_loader.ColorizeImageNet(root, split='train', set='small')
    path = ds.files['train'][100]
    pixels = np.array(PIL.Image.open(path), dtype=np.uint8)
    largest_dim = np.max(pixels.shape)
    assert largest_dim == 400
コード例 #5
0
def test_init_gmm():
    """Build a 'tiny' soft-bin train set from a cached GMM and mean-L file.

    Passes as long as the dataset constructor does not raise.
    """
    # Cached GMM and mean lightness stored in a training-log directory.
    gmm_file = '/srv/data1/arunirc/Research/colorize-fcn/colorizer-fcn/logs/MODEL-fcn32s_color_CFG-014_VCS-db517d6_TIME-20171230-212406/gmm.pkl'
    mean_l_file = '/srv/data1/arunirc/Research/colorize-fcn/colorizer-fcn/logs/MODEL-fcn32s_color_CFG-014_VCS-db517d6_TIME-20171230-212406/mean_l.npy'
    data_loader.ColorizeImageNet(root,
                                 split='train',
                                 set='tiny',
                                 bins='soft',
                                 gmm_path=gmm_file,
                                 mean_l_path=mean_l_file)
    print('Test passed: test_init_gmm')
コード例 #6
0
def test_single_read_dimcheck():
    """Check that the hue and chroma maps of train sample 0 share a shape."""
    print('Entering: test_single_read_dimcheck')
    ds = data_loader.ColorizeImageNet(root, split='train', set='small')
    _, label_pair = ds.__getitem__(0)
    assert len(label_pair) == 2
    hue = label_pair[0].numpy()
    chroma = label_pair[1].numpy()
    same_shape = chroma.shape == hue.shape
    assert same_shape, \
        'Labels (Hue and Chroma maps) should have same dimensions.'
    print('Test passed: test_single_read_dimcheck')
コード例 #7
0
def main():
    """Rank up to 100 'val' images by a colorfulness score and save them.

    The score is selected by the module-level ``method``:
      * 'saturation' -- mean of the HSV saturation channel,
      * 'chroma'     -- mean of saturation * value,
      * 'rgbvar'     -- mean of the per-pixel variance along axis 2 of the
                        rescaled image.
    Images that are not 2-D with 3 bands are skipped. Survivors are re-read
    from disk and written to ``exp_folder`` as ``<rank>_<score>.jpg``, highest
    score first.

    Raises:
        ValueError: if ``method`` is not one of the scores listed above.
    """
    # Validate up front instead of failing after work has already been done.
    if method not in ('saturation', 'chroma', 'rgbvar'):
        raise ValueError('Unknown method: %s' % (method,))

    dataset = data_loader.ColorizeImageNet(data_root,
                                           split='val',
                                           set='small',
                                           bins='one-hot')

    if not osp.exists(exp_folder):
        os.makedirs(exp_folder)

    val_files = dataset.files['val']
    im_satval = []
    im_filenames = []
    # Do not run past the end of a split shorter than 100 images.
    for i in xrange(min(100, len(val_files))):
        print(i)

        # Original RGB image
        img_file = val_files[i]
        im_orig = PIL.Image.open(img_file)

        # Invalid image formats (e.g. grayscale or CMYK)
        if len(im_orig.size) != 2 or len(im_orig.getbands()) != 3:
            continue

        im_orig = dataset.rescale(im_orig)
        im_hsv = skimage.color.rgb2hsv(im_orig)
        im_filenames.append(img_file)

        if method == 'saturation':
            im_satval.append(im_hsv[:, :, 1].mean())
        elif method == 'chroma':
            chroma = im_hsv[:, :, 1] * im_hsv[:, :, 2]
            im_satval.append(chroma.mean())
        else:  # 'rgbvar' -- already validated above
            im_satval.append(im_orig.var(axis=2).mean())

    # sort: greatest value first
    sorted_idx = np.argsort(im_satval)[::-1]

    for rank, idx in enumerate(sorted_idx):
        score = im_satval[idx]
        print(score)
        im_orig = skimage.io.imread(im_filenames[idx])
        out_im_file = osp.join(exp_folder,
                               '{0:0{width}}'.format(rank, width=6) +
                               '_' + str(score) + '.jpg')
        skimage.io.imsave(out_im_file, im_orig)
コード例 #8
0
def test_grayscale_read():
    """Check that loading a known single-channel (grayscale) file does not fail.

    Per the original note, the loader handles such images by skipping to the
    previous image; this test only verifies that ``__getitem__`` succeeds.
    """
    print('Entering: test_grayscale_read')
    ds = data_loader.ColorizeImageNet(root, split='train', set='small')
    gray_idx = 4606
    raw = np.asarray(PIL.Image.open(ds.files['train'][gray_idx]),
                     dtype=np.uint8)
    # Guard: the chosen index must really point at a 2-D (grayscale) image.
    assert raw.ndim == 2, 'Check that selected image is indeed grayscale.'
    _img, _lbl = ds.__getitem__(gray_idx)
    print('Test passed: test_grayscale_read')
コード例 #9
0
def test_cmyk_read():
    """Check that loading a known 4-channel (CMYK) file does not fail.

    Per the original note, the loader handles such images by skipping to the
    previous image; this test only verifies that ``__getitem__`` succeeds.
    """
    print('Entering: test_cmyk_read')
    ds = data_loader.ColorizeImageNet(root, split='train', set='small')
    cmyk_idx = 44896
    raw = np.asarray(PIL.Image.open(ds.files['train'][cmyk_idx]),
                     dtype=np.uint8)
    # Guard: the chosen index must really point at a 4-channel image.
    assert raw.shape[2] == 4, 'Check that selected image is indeed CMYK.'
    _img, _lbl = ds.__getitem__(cmyk_idx)
    print('Test passed: test_cmyk_read')
コード例 #10
0
def test_dataset_read():
    """Read through the entire 'small' train split and check tensor types.

    Prints the index, file path and input size of every sample.
    """
    dataset = data_loader.ColorizeImageNet(root, split='train', set='small')
    train_files = dataset.files['train']

    for i in xrange(len(dataset)):
        img, lbl = dataset[i]
        # isinstance (rather than type() ==) also holds on PyTorch versions
        # where every tensor's concrete type is torch.Tensor.
        assert isinstance(lbl, torch.FloatTensor)
        assert isinstance(img, torch.FloatTensor)
        print('iter: %d,\t file: %s,\t imsize: %s'
              % (i, train_files[i], img.size()))
コード例 #11
0
def test_train_loader():
    """Fetch one batch from a train DataLoader and check the label structure.

    The label must be a (hue-map, chroma-map) pair of equal shapes.
    """
    print('Entering: test_train_loader')
    train_loader = torch.utils.data.DataLoader(
        data_loader.ColorizeImageNet(root, split='train', set='small'),
        batch_size=1,
        shuffle=False)
    # Builtin next() works on any iterator; the iterator's .next() method is
    # a Python 2 alias that recent PyTorch releases no longer provide.
    img, label = next(iter(train_loader))
    assert len(label) == 2, \
        'Network should predict a 2-tuple: hue-map and chroma-map.'
    im_hue = label[0].numpy()
    im_chroma = label[1].numpy()
    assert im_chroma.shape == im_hue.shape, \
        'Labels (Hue and Chroma maps) should have same dimensions.'
    print('Test passed: test_train_loader')
コード例 #12
0
def main():
    """Train an FCN-32s colorization network.

    Command-line options choose the GPU, configuration id, binning scheme,
    number of hue/chroma bins, dataset root, an optional initial model and an
    optional checkpoint to resume from. Training settings come from
    ``configurations[args.config]``; optional keys fall back to the defaults
    listed in the dataset section below.

    Raises:
        NotImplementedError: if the config names an optimizer other than
            SGD or Adam.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--exp_name', default='fcn32s_color')
    parser.add_argument('-g', '--gpu', type=int, required=True)
    parser.add_argument('-c',
                        '--config',
                        type=int,
                        default=1,
                        choices=configurations.keys())
    parser.add_argument('-b',
                        '--binning',
                        default='soft',
                        choices=('soft', 'one-hot', 'uniform'))
    parser.add_argument('-k', '--numbins', type=int, default=128)
    parser.add_argument(
        '-d',
        '--dataset_path',
        default='/vis/home/arunirc/data1/datasets/ImageNet/images/')
    parser.add_argument('-m', '--model_path', default=None)
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    # -----------------------------------------------------------------------------
    # 0. setup
    # -----------------------------------------------------------------------------
    gpu = args.gpu
    # Work on a copy so the shared configuration table is not mutated.
    cfg = configurations[args.config].copy()
    # A binning scheme fixed by the config overrides the command line. Apply
    # it before recording 'bin_type' so the logged config matches what is
    # actually trained.
    if 'binning' in cfg:
        args.binning = cfg['binning']
    cfg.update({'bin_type': args.binning, 'numbins': args.numbins})
    resume = args.resume
    if resume:
        out, _ = osp.split(resume)
    else:
        out = get_log_dir(args.exp_name, args.config, cfg, verbose=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)
        torch.backends.cudnn.enabled = True

    # -----------------------------------------------------------------------------
    # 1. dataset
    # -----------------------------------------------------------------------------
    # Custom dataset class defined in `data_loader.ColorizeImageNet`
    root = args.dataset_path
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}

    # Optional config keys and their defaults.
    img_lowpass = cfg.get('img_lowpass', None)
    train_set = cfg.get('train_set', 'full')
    val_set = cfg.get('val_set', 'small')
    gmm_path = cfg.get('gmm_path', None)
    mean_l_path = cfg.get('mean_l_path', None)
    im_size = cfg.get('im_size', (256, 256))
    batch_size = cfg.get('batch_size', 1)
    uniform_sigma = cfg.get('uniform_sigma', 'default')

    # Options shared by the train and val datasets.
    dataset_kwargs = dict(bins=args.binning,
                          log_dir=out,
                          num_hc_bins=args.numbins,
                          img_lowpass=img_lowpass,
                          im_size=im_size,
                          gmm_path=gmm_path,
                          mean_l_path=mean_l_path,
                          uniform_sigma=uniform_sigma)

    # DEBUG: set='tiny', shuffle=False
    train_loader = torch.utils.data.DataLoader(
        data_loader.ColorizeImageNet(root, split='train', set=train_set,
                                     **dataset_kwargs),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)

    val_loader = torch.utils.data.DataLoader(
        data_loader.ColorizeImageNet(root, split='val', set=val_set,
                                     **dataset_kwargs),
        batch_size=1,
        shuffle=False,
        **kwargs)

    # -----------------------------------------------------------------------------
    # 2. model
    # -----------------------------------------------------------------------------
    model = models.FCN32sColor(n_class=args.numbins, bin_type=args.binning)
    if args.model_path:
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = 0
    start_iteration = 0
    if resume:
        # A resume checkpoint takes precedence over --model_path weights.
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    if cuda:
        model = model.cuda()

    # -----------------------------------------------------------------------------
    # 3. optimizer
    # -----------------------------------------------------------------------------
    params = filter(lambda p: p.requires_grad, model.parameters())
    optim_name = cfg.get('optim', 'sgd').lower()  # SGD when unspecified
    if optim_name == 'sgd':
        optim = torch.optim.SGD(params,
                                lr=cfg['lr'],
                                momentum=cfg['momentum'],
                                weight_decay=cfg['weight_decay'])
    elif optim_name == 'adam':
        optim = torch.optim.Adam(params,
                                 lr=cfg['lr'],
                                 weight_decay=cfg['weight_decay'])
    else:
        raise NotImplementedError('Optimizers: SGD or Adam')

    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    # -----------------------------------------------------------------------------
    # Sanity-check: forward pass with a single sample
    # -----------------------------------------------------------------------------
    DEBUG = False
    if DEBUG:
        from torch.autograd import Variable
        dataiter = iter(val_loader)
        img, label = next(dataiter)
        model.eval()
        print('Labels: ' + str(label.size()))  # batchSize x num_class
        print('Input: ' + str(img.size()))  # batchSize x 1 x (im_size)
        if val_loader.dataset.bins == 'one-hot':
            inputs = Variable(img)
            if cuda:
                inputs = inputs.cuda()
            outputs = model(inputs)
            assert len(outputs) == 2, \
                'Network should predict a 2-tuple: hue-map and chroma-map.'
            hue_map = outputs[0].data
            chroma_map = outputs[1].data
            assert hue_map.size() == chroma_map.size(), \
                'Outputs should have same dimensions.'
            sz_h = hue_map.size()
            sz_im = img.size()
            assert sz_im[2] == sz_h[2] and sz_im[3] == sz_h[3], \
                'Spatial dims should match for input and output.'
        elif val_loader.dataset.bins == 'soft':
            inputs = Variable(img)
            if cuda:
                inputs = inputs.cuda()
            outputs = model(inputs)
            # TODO: assertions
        import pdb
        pdb.set_trace()  # breakpoint 0632fd52 //

        model.train()

    # -----------------------------------------------------------------------------
    # Training
    # -----------------------------------------------------------------------------
    trainer = train.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        out=out,
        max_iter=cfg['max_iteration'],
        interval_validate=cfg.get('interval_validate', len(train_loader)),
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
コード例 #13
0
def main():
    """Colorize images 10000-10099 of ``split`` with a trained FCN-8s model.

    For every usable image, write a JPEG tiling three panels side by side:
    the rescaled original, the colorization from the dataset labels, and the
    colorization from the network's softmax predictions (``method='max'``).
    Output goes to ``<exp_folder>/colorized-output-max-<split>/<index>.jpg``.

    Relies on names defined elsewhere in the module: ``binning``,
    ``data_root``, ``split``, ``train_set``, ``GMM_PATH``, ``MEAN_L_PATH``,
    ``MODEL_PATH``, ``exp_folder`` and ``cuda``.

    Raises:
        NotImplementedError: if ``binning`` is neither 'soft' nor 'uniform'.
    """
    # -----------------------------------------------------------------------------
    #   Setup
    # -----------------------------------------------------------------------------
    if binning == 'soft':
        # Build the dataset for `split`, the same split whose files are
        # indexed below, so each sample matches its original image.
        dataset = data_loader.ColorizeImageNet(data_root,
                                               split=split,
                                               set='small',
                                               bins='soft',
                                               num_hc_bins=16,
                                               gmm_path=GMM_PATH,
                                               mean_l_path=MEAN_L_PATH)
        model = models.FCN8sColor(n_class=16, bin_type='soft')

    elif binning == 'uniform':
        dataset = data_loader.ColorizeImageNet(data_root,
                                               split=split,
                                               bins=binning,
                                               num_hc_bins=256,
                                               set=train_set,
                                               im_size=(256, 256),
                                               gmm_path=None,
                                               mean_l_path=MEAN_L_PATH,
                                               uniform_sigma='default')
        model = models.FCN8sColor(n_class=256, bin_type='uniform')

    else:
        raise NotImplementedError

    checkpoint = torch.load(MODEL_PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    if cuda:
        model.cuda()  # once, not on every iteration

    out_dir = osp.join(exp_folder, 'colorized-output-max-' + split)
    if not osp.exists(out_dir):
        os.makedirs(out_dir)

    # -----------------------------------------------------------------------------
    #   Colorize 100 images from dataset
    # -----------------------------------------------------------------------------
    for i in xrange(10000, 10100):
        print(i)

        # Original RGB image; skip unusable formats before the expensive
        # dataset read.
        img_file = dataset.files[split][i]
        im_orig = PIL.Image.open(img_file)
        if len(im_orig.size) != 2 or len(im_orig.getbands()) != 3:
            continue
        im_orig = dataset.rescale(im_orig)

        input_im, labels = dataset[i]
        gmm = dataset.gmm
        mean_l = dataset.mean_l

        # "Ground-truth" colorization
        labels = labels.numpy()
        img = input_im.numpy().squeeze()
        im_rgb = utils.colorize_image_hc(labels, img, gmm, mean_l)

        # Get Hue-Chroma bin predictions from colorizer network
        inputs = Variable(input_im.unsqueeze(0))
        if cuda:
            inputs = inputs.cuda()
        # Softmax over the bin (channel) axis, stated explicitly rather than
        # relying on the deprecated implicit-dim behaviour.
        outputs = F.softmax(model(inputs), dim=1)
        preds = outputs.squeeze().permute(1, 2, 0).data.cpu().numpy()

        # Get a colorized image using predicted Hue-Chroma bins
        im_pred = utils.colorize_image_hc(preds,
                                          img,
                                          gmm,
                                          mean_l,
                                          method='max')

        # 10-pixel black strip between panels
        gap = np.zeros([im_rgb.shape[0], 10, 3], dtype=np.uint8)
        tiled_img = np.concatenate((im_orig, gap, im_rgb, gap, im_pred),
                                   axis=1)

        out_im_file = osp.join(out_dir, str(i) + '.jpg')
        skimage.io.imsave(out_im_file, tiled_img)
コード例 #14
0
def main():
    """Train an FCN-8s colorization network.

    Initial weights come from the first available of:
      1. ``--model_path`` (weights only),
      2. ``--resume`` (weights plus epoch, iteration and optimizer state),
      3. the FCN-16s checkpoint named by the config key
         ``'fcn16s_pretrained_model'``.
    Other settings come from ``configurations[args.config]``; optional keys
    fall back to the defaults listed in the dataset section below.

    Raises:
        NotImplementedError: if ``--data_par`` is given (multi-GPU training is
            not supported yet) or the config names an optimizer other than
            SGD or Adam.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', type=int, required=True)
    parser.add_argument('-c', '--config', type=int, default=15,
                        choices=configurations.keys())
    parser.add_argument('-b', '--binning', default='uniform',
                        choices=('soft', 'one-hot', 'uniform'))
    parser.add_argument('-k', '--numbins', type=int, default=128)
    parser.add_argument('-d', '--dataset_path',
                        default='/vis/home/arunirc/data1/datasets/ImageNet/images/')
    parser.add_argument('-m', '--model_path', default=None)
    parser.add_argument('--data_par', action='store_true', default=False,
                        help='Use DataParallel for multi-gpu training')
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    gpu = args.gpu
    # The config must specify path to pre-trained model as value for the
    # key 'fcn16s_pretrained_model'. Work on a copy so the shared
    # configuration table is not mutated.
    cfg = configurations[args.config].copy()
    # A binning scheme fixed by the config overrides the command line. Apply
    # it before recording 'bin_type' so the logged config matches what is
    # actually trained.
    if 'binning' in cfg:
        args.binning = cfg['binning']
    cfg.update({'bin_type': args.binning, 'numbins': args.numbins})
    out = get_log_dir('fcn8s_color', args.config, cfg, verbose=False)
    resume = args.resume

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    # -----------------------------------------------------------------------------
    # 1. dataset
    # -----------------------------------------------------------------------------
    root = args.dataset_path
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}

    # Optional config keys and their defaults.
    img_lowpass = cfg.get('img_lowpass', None)
    train_set = cfg.get('train_set', 'full')
    val_set = cfg.get('val_set', 'small')
    gmm_path = cfg.get('gmm_path', None)
    mean_l_path = cfg.get('mean_l_path', None)
    im_size = cfg.get('im_size', (256, 256))
    # Previously hard-coded to 24 (the config value was ignored); keep 24 as
    # the default but honour an explicit config value.
    batch_size = cfg.get('batch_size', 24)
    uniform_sigma = cfg.get('uniform_sigma', 'default')

    # Options shared by the train and val datasets.
    dataset_kwargs = dict(bins=args.binning,
                          log_dir=out,
                          num_hc_bins=args.numbins,
                          img_lowpass=img_lowpass,
                          im_size=im_size,
                          gmm_path=gmm_path,
                          mean_l_path=mean_l_path,
                          uniform_sigma=uniform_sigma)

    # DEBUG: set='tiny', shuffle=False
    train_loader = torch.utils.data.DataLoader(
        data_loader.ColorizeImageNet(root, split='train', set=train_set,
                                     **dataset_kwargs),
        batch_size=batch_size, shuffle=True, **kwargs)

    val_loader = torch.utils.data.DataLoader(
        data_loader.ColorizeImageNet(root, split='val', set=val_set,
                                     **dataset_kwargs),
        batch_size=1, shuffle=False, **kwargs)

    # -----------------------------------------------------------------------------
    # 2. model
    # -----------------------------------------------------------------------------
    model = models.FCN8sColor(n_class=args.numbins, bin_type=args.binning)

    start_epoch = 0
    start_iteration = 0
    # Set only when training state is actually restored from --resume.
    resume_checkpoint = None
    if args.model_path:
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    elif resume:
        # HACK: takes very long ... better to start a new expt with init from
        # `args.model_path`
        resume_checkpoint = torch.load(resume)
        model.load_state_dict(resume_checkpoint['model_state_dict'])
        start_epoch = resume_checkpoint['epoch']
        start_iteration = resume_checkpoint['iteration']
    else:
        fcn16s = models.FCN16sColor(n_class=args.numbins,
                                    bin_type=args.binning)
        fcn16s.load_state_dict(
            torch.load(cfg['fcn16s_pretrained_model'])['model_state_dict'])
        model.copy_params_from_fcn16s(fcn16s)

    if cuda:
        model = model.cuda()

    if args.data_par:
        # TODO: wrap the model in torch.nn.DataParallel once supported.
        raise NotImplementedError

    # -----------------------------------------------------------------------------
    # 3. optimizer
    # -----------------------------------------------------------------------------
    params = filter(lambda p: p.requires_grad, model.parameters())
    optim_name = cfg.get('optim', 'sgd').lower()  # SGD when unspecified
    if optim_name == 'sgd':
        optim = torch.optim.SGD(params,
                                lr=cfg['lr'],
                                momentum=cfg['momentum'],
                                weight_decay=cfg['weight_decay'])
    elif optim_name == 'adam':
        optim = torch.optim.Adam(params,
                                 lr=cfg['lr'],
                                 weight_decay=cfg['weight_decay'])
    else:
        raise NotImplementedError('Optimizers: SGD or Adam')

    # Restore optimizer state only from the checkpoint the weights were
    # resumed from (never from a --model_path initialisation).
    if resume_checkpoint is not None:
        optim.load_state_dict(resume_checkpoint['optim_state_dict'])

    # -----------------------------------------------------------------------------
    # Training
    # -----------------------------------------------------------------------------
    trainer = train.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        out=out,
        max_iter=cfg['max_iteration'],
        interval_validate=cfg.get('interval_validate', len(train_loader)),
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()