Example #1
def main(config: ConfigParser):
    # Get a logging.getLogger; the default log level is DEBUG
    logger = config.get_logger('train')
    # Data module
    # Read the data loader name from config.json via config, instantiate it,
    # and fill its constructor with the arguments given in the JSON
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model module
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metrics module
    criterion = getattr(module_loss, config['loss'])
    # These entries are functions (sometimes classes); their names are available via the __name__ attribute
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # Optimizer module
    # filter keeps only the parameters with requires_grad=True (drops frozen ones)
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    # Learning-rate decay schedule
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    # Train the model
    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()
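The comments above describe config.init_obj as reading a component's type and arguments from config.json and instantiating it. In the pytorch-template style projects these examples come from, that helper is roughly the sketch below (details may differ between forks):

def init_obj(self, name, module, *args, **kwargs):
    # Instantiate config[name]['type'] found in `module`, passing config[name]['args'];
    # keyword arguments given at the call site override the JSON args in this sketch.
    module_name = self[name]['type']
    module_args = dict(self[name]['args'])
    module_args.update(kwargs)
    return getattr(module, module_name)(*args, **module_args)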
Example #2
def main(config: ConfigParser):
    logger = config.get_logger("train")

    # setup data_loader instances
    data_loader = config.init_obj("data_loader", module_data)
    valid_data_loader = data_loader.split_validation()

    # build model architecture, then print to console
    model = config.init_obj("arch", module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    criterion = config.init_obj("criterion", module_criterion)
    metrics = [getattr(module_metric, met) for met in config["metrics"]]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj("optimizer", module_optim, trainable_params)

    lr_scheduler = config.init_obj("lr_scheduler", torch.optim.lr_scheduler,
                                   optimizer)

    trainer = Trainer(
        model,
        criterion,
        metrics,
        optimizer,
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        lr_scheduler=lr_scheduler,
    )

    trainer.train()
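For reference, the config entries consumed by init_obj in the examples above usually look like the fragment below, shown here as a Python dict (component names and values are purely illustrative, not taken from these projects):

config_fragment = {
    "data_loader": {
        "type": "MnistDataLoader",
        "args": {"data_dir": "data/", "batch_size": 128, "shuffle": True,
                 "validation_split": 0.1, "num_workers": 2},
    },
    "optimizer": {"type": "Adam", "args": {"lr": 1e-3, "weight_decay": 0}},
    "lr_scheduler": {"type": "StepLR", "args": {"step_size": 50, "gamma": 0.1}},
    "metrics": ["accuracy", "top_k_acc"],
}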
Example #3
def main(cfg_dict: DictConfig):

    # TODO: erase previous logs in the folder at every run
    config = ConfigParser(cfg_dict)
    logger = config.get_logger('train')

    # setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # build model architecture, then print to console
    model = config.init_obj('arch', module_arch)
    # logger.info(model)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)

    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler,
                                   optimizer)

    trainer = Trainer(model,
                      criterion,
                      metrics,
                      optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()
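Example #3 receives a Hydra DictConfig; a minimal entry point that feeds such a main might look like this (config_path and config_name are assumptions, not taken from the source):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf", config_name="config")
def run(cfg_dict: DictConfig):
    main(cfg_dict)

if __name__ == "__main__":
    run()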
Example #4
def recognize_stroke(
    stroke,
    config_fn='../saved/models/Seq2SeqHandwritingRecognition/1114_091246/config.json',
    resume='../saved/models/Seq2SeqHandwritingRecognition/1114_091246/model_best.pth'
):
    # Parsing the config
    config = read_json(config_fn)
    config = ConfigParser(config, resume)
    # set up device and data_loader
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = module_data.HandWritingDataset('../data')
    # build model architecture and load weights
    model = config.init_obj('arch',
                            module_arch,
                            char2idx=dataset.char2idx,
                            device=device)
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    # prepare model for inference
    model = model.to(device)
    model.eval()
    # Recognize the text from the input stroke sequence
    with torch.no_grad():
        predicted_seq = model.recognize_sample(stroke)
        predicted_text = dataset.tensor2sentence(torch.tensor(predicted_seq))
    # Clean notebooks folder
    shutil.rmtree('saved')
    return predicted_text
def generate_conditionally(
    text,
    config_fn='../saved/models/ConditionalHandwriting/1114_101215/config.json',
    resume='../saved/models/ConditionalHandwriting/1114_101215/model_best.pth'
):
    # Parsing the config
    config = read_json(config_fn)
    config = ConfigParser(config, resume)
    # set up device and data_loader
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = HandWritingDataset('../data')
    # build model architecture and load weights
    model = config.init_obj('arch',
                            module_arch,
                            char2idx=dataset.char2idx,
                            device=device)
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    # prepare model for inference
    model = model.to(device)
    model.eval()
    # Generate handwriting conditioned on the input text
    with torch.no_grad():
        sampled_stroke = model.generate_conditional_sample(text)
    # Clean notebooks folder
    shutil.rmtree('saved')
    return sampled_stroke
Example #6
def main(config: ConfigParser, local_master: bool, logger=None):
    # setup dataset and data_loader instances
    train_dataset = config.init_obj('train_dataset', pick_dataset_module)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) \
        if config['distributed'] else None

    is_shuffle = not config['distributed']
    train_data_loader = config.init_obj('train_data_loader',
                                        torch.utils.data.dataloader,
                                        dataset=train_dataset,
                                        sampler=train_sampler,
                                        batch_size=8,
                                        shuffle=is_shuffle,
                                        collate_fn=BatchCollateFn())

    val_dataset = config.init_obj('validation_dataset', pick_dataset_module)
    val_data_loader = config.init_obj('val_data_loader',
                                      torch.utils.data.dataloader,
                                      dataset=val_dataset,
                                      collate_fn=BatchCollateFn())
    logger.info(
        f'Dataloader instances created. Train batch size: {train_data_loader.batch_size} '
        f'Val batch size: {val_data_loader.batch_size}.') if local_master else None
    logger.info(f'Train datasets: {len(train_dataset)} samples '
                f'Validation datasets: {len(val_dataset)} samples.'
                ) if local_master else None

    # build model architecture
    pick_model = config.init_obj('model_arch', pick_arch_module)
    logger.info(
        f'Model created, trainable parameters: {pick_model.model_parameters()}.'
    ) if local_master else None

    # build optimizer, learning rate scheduler.
    optimizer = config.init_obj('optimizer', torch.optim,
                                pick_model.parameters())
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler,
                                   optimizer)
    logger.info(
        'Optimizer and lr_scheduler created.') if local_master else None

    # print training related information
    logger.info(
        'Max_epochs: {} Log_per_step: {} Validation_per_step: {}.'.format(
            config['trainer']['epochs'],
            config['trainer']['log_step_interval'],
            config['trainer']['val_step_interval'])) if local_master else None

    logger.info('Training start...') if local_master else None
    trainer = Trainer(pick_model,
                      optimizer,
                      config=config,
                      data_loader=train_data_loader,
                      valid_data_loader=val_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()
    logger.info('Training end...') if local_master else None
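The logger.info(...) if local_master else None expressions above work but are easy to misread; a small helper that logs only on the master process is a common alternative (a sketch, not part of the source):

def log_on_master(logger, local_master, msg):
    # Emit the message only on the master process to avoid duplicate
    # log lines when running with DistributedDataParallel.
    if local_master and logger is not None:
        logger.info(msg)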
Example #7
def load_trained_model_by_path(checkpoint_path, config):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    loaded_epoch = checkpoint['epoch']

    print('loaded', checkpoint_path, 'from epoch', loaded_epoch)
    # Load model with parameters from config file
    config_parser = ConfigParser(config, dry_run=True)
    model = config_parser.init_obj('arch', module_arch)

    # TODO: WARNING: Leaving some mipmap layer weights unassigned might lead to erroneous
    #  results (maybe they're not set to zero by default)
    # Assign model weights and set to eval (not train) mode
    #model.load_state_dict(checkpoint['state_dict'], strict=(not zero_other_mipmaps))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    return model, loaded_epoch
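A hypothetical caller of load_trained_model_by_path; the paths and the input shape below are placeholders, not values from the source:

config = read_json('saved/models/MyModel/config.json')
model, epoch = load_trained_model_by_path('saved/models/MyModel/model_best.pth', config)
with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224))  # assumed input shape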
Example #8
def main(cfg_dict: DictConfig):

    config = ConfigParser(cfg_dict)
    logger = config.get_logger('test')

    # setup data_loader instances
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2)

    # build model architecture
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    logger.info('Loading checkpoint: {} ...'.format(config['resume']))
    checkpoint = torch.load(config['resume'])
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))

    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader)):

            # TODO: overlap objects with overlap_objects_from_batch in util.py
            # TODO: check model's output is correct for the loss_fn

            data, target = data.to(device), target.to(device)
            output = model(data)

            #
            # save sample images, or do something with output here
            #
            # computing loss, metrics on test set
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for j, metric in enumerate(metric_fns):
                total_metrics[j] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples
        for i, met in enumerate(metric_fns)
    })
    logger.info(log)
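The metric handles obtained with getattr(module_metric, met) above are plain functions taking (output, target); a typical implementation in this style is sketched below (the accuracy metric is illustrative):

import torch

def accuracy(output, target):
    # Fraction of samples whose argmax prediction matches the target label.
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        correct = torch.sum(pred == target).item()
    return correct / len(target)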
Example #9
    def pred(self, paths, metas, m_cfg, id):
        print('pred')
        self.cfg = m_cfg
        res = Response()
        if len(paths) != len(metas):
            res.code = -2
            res.msg = "The length of images and meta is not same."
            return res
        # if self.pred_th is not None:
        #     if self.pred_th.is_alive():
        #         res.code = -3
        #         res.msg = "There is a task running, please wait it finish."
        #         return res
        try:
            m_typename = m_cfg["name"].split("-")[1]
            if m_typename == "Deeplab" or m_typename == "UNet":
                from .predthread import SegPredThread
                self.device = torch.device(
                    'cuda:0' if self.n_gpu_use > 0 else 'cpu')
                torch.set_grad_enabled(False)
                m_cfg["save_dir"] = str(self.tmp_path)
                config = ConfigParser(m_cfg, Path(m_cfg["path"]))
                self.logger = config.get_logger('PredServer')
                self.model = config.init_obj('arch', module_arch)
                self.logger.info('Loading checkpoint: {} ...'.format(
                    config.resume))
                if self.n_gpu_use > 1:
                    self.model = torch.nn.DataParallel(self.model)
                if self.n_gpu_use > 0:
                    checkpoint = torch.load(config.resume)
                else:
                    checkpoint = torch.load(config.resume,
                                            map_location=torch.device('cpu'))

                state_dict = checkpoint['state_dict']
                self.model.load_state_dict(state_dict)
                self.model = self.model.to(self.device)
                self.model.eval()

                if "crop_size" in config["tester"]:
                    self.crop_size = config["tester"]["crop_size"]

                if 'postprocessor' in config["tester"]:
                    module_name = config["tester"]['postprocessor']['type']
                    module_args = dict(
                        config["tester"]['postprocessor']['args'])
                    self.postprocessor = getattr(postps_crf,
                                                 module_name)(**module_args)

                self.tmp_path.mkdir(parents=True, exist_ok=True)

                self.pred_ths.append(
                    SegPredThread(self, paths, metas, self.tmp_path, id))
            elif m_typename == "CycleGAN":
                from .predthread import CycleGANPredThread
                from model import CycleGANOptions, CycleGANModel
                # config = ConfigParser(m_cfg, Path(m_cfg["path"]))
                opt = CycleGANOptions(**m_cfg["arch"]["args"])
                opt.batch_size = self.batch_size
                opt.serial_batches = True
                opt.no_flip = True  # no flip;
                opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
                opt.isTrain = False
                opt.gpu_ids = []
                for i in range(0, self.n_gpu_use):
                    opt.gpu_ids.append(i)
                opt.checkpoints_dir = str(self.tmp_path)
                opt.preprocess = "none"
                opt.direction = 'AtoB'
                self.model = CycleGANModel(opt)

                orig_save_dir = self.model.save_dir
                self.model.save_dir = ""
                self.model.load_networks(m_cfg["path"])
                self.model.save_dir = orig_save_dir
                torch.set_grad_enabled(False)
                self.model.set_requires_grad(
                    [self.model.netG_A, self.model.netG_B], False)

                self.pred_ths.append(
                    CycleGANPredThread(self, paths, metas, self.tmp_path, id))
            else:
                raise NotImplementedError(f"Model type {m_typename} is not supported.")

            print('NotifyStartThread')
            self.pred_ths[-1].start()
            # self.pred_th.is_alive()
        except Exception as e:
            res.code = -1
            res.msg = str(e)
            return res

        res.code = 0
        res.msg = "Success"
        return res
Example #10
def main(cfg_dict: DictConfig):
    generate = False
    load_gen = True
    save = True
    # remove_eigs = True
    remove_eigs = False

    config = ConfigParser(cfg_dict)
    T_rec, T_pred = config['n_timesteps'], config['seq_length'] - config['n_timesteps']
    logger = config.get_logger('test')

    gt = True
    # gt = True
    model_name = 'ddpae-iccv'
    # model_name = 'DRNET'
    # model_name = 'scalor'
    # model_name = 'sqair'
    s_directory = os.path.join(config['data_loader']['args']['data_dir'], 'test_data')
    res_directory = os.path.join(config['data_loader']['args']['data_dir'], 'res_data')
    load_gen_directory = os.path.join(config['data_loader']['args']['data_dir'],
                                      'results')
    # # TODO: Testing features
    # load_gen_directory = os.path.join(config['data_loader']['args']['data_dir'], 'test_data')

    if not os.path.exists(s_directory):
        os.makedirs(s_directory)
    if not os.path.exists(res_directory):
        os.makedirs(res_directory)
    dataset_dir = os.path.join(s_directory, config['data_loader']['args']['dataset_case']+
                               '_Len-'+str(config['seq_length'])+'_Nts-'+str(config['n_timesteps'])+'.npy')
    results_dir = os.path.join(res_directory, config['data_loader']['args']['dataset_case']+
                               '_Len-'+str(config['seq_length'])+'_Nts-'+str(config['n_timesteps'])+'.npz')
    all_data = []
    if not os.path.exists(dataset_dir) and generate:
        config['data_loader']['args']['shuffle'] = False
        config['data_loader']['args']['training'] = False
        config['data_loader']['args']['validation_split'] = 0.0
        data_loader = config.init_obj('data_loader', module_data)

        for i, data in enumerate(tqdm(data_loader)):
            all_data.append(data)
        all_data = torch.cat(all_data, dim=0).numpy()
        print(all_data.shape)
        np.save(dataset_dir, all_data)
        print(config['data_loader']['args']['dataset_case']+ ' data generated in: '+s_directory)
        exit()
    if os.path.exists(dataset_dir):
        print('LOADING EXISTING DATA FROM: ' + dataset_dir)
        inps = torch.from_numpy(np.load(dataset_dir))
        if os.path.exists(load_gen_directory) and load_gen:
            if model_name == 'ddpae-iccv':
                outs = torch.from_numpy(
                    np.load(os.path.join(
                        load_gen_directory,
                        model_name +'--'+config['data_loader']['args']['dataset_case']+
                        '_Len-'+str(config['seq_length'])+'_Nts-'+str(config['n_timesteps'])+'.npy')))

            else:
                with np.load(os.path.join(
                        load_gen_directory,
                        model_name +'_'+config['data_loader']['args']['dataset_case']+'.npz')) as outputs:
                    if model_name == 'scalor':
                        outs = torch.from_numpy(outputs["pred"]).permute(0,1,3,2).unsqueeze(2)
                    elif model_name == 'DRNET':
                        outs = torch.from_numpy(outputs["pred"]).unsqueeze(2).float()
                    else:
                        outs = torch.from_numpy(outputs["pred"]).unsqueeze(2)
                    print('Inps and Outs shapes', inps.shape, outs.shape)
            loaded_dataset = TensorDataset(inps, outs)
        else:
            loaded_dataset = TensorDataset(inps)
        data_loader = DataLoader(loaded_dataset, batch_size=40, shuffle=False, sampler=None,
                            batch_sampler=None, num_workers=2, collate_fn=None,
                            pin_memory=False)
    else:
        print('something has gone wrong if you end up here')
        exit()
        config['data_loader']['args']['shuffle'] = False
        config['data_loader']['args']['training'] = False
        config['data_loader']['args']['validation_split'] = 0.0
        data_loader = config.init_obj('data_loader', module_data)
    # build model architecture
    if not load_gen:
        model = config.init_obj('arch', module_arch)
    # logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, met) for met in ["mse", "mae", "bce", "mssim", "mlpips"]]

    if not load_gen:
        logger.info('Loading checkpoint: {} ...'.format(config['resume']))
        checkpoint = torch.load(config['resume'])
        state_dict = checkpoint['state_dict']
        if config['n_gpu'] > 1:
            model = torch.nn.DataParallel(model)
        model.load_state_dict(state_dict)

        # prepare model for testing
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        if remove_eigs:
            A_modified, indices, e = remove_eig_under_t(
                model.koopman.dynamics.dynamics.weight.data, t=0.7)
            A_modified = torch.from_numpy(A_modified.real).to(device)
            model.koopman.dynamics.dynamics.weight.data = A_modified

    total_loss = 0.0


    total_metrics = [torch.zeros(len(metric_fns)), torch.zeros(len(metric_fns))]

    # TODO: Here we can change the model's K, and crop the eigenvalues under certain module threshold.
    # If the new prediction is longer, evaluate only the new part:
    # T_pred = 8
    all_pred, all_rec = [], []
    with torch.no_grad():
        for i, data in enumerate(tqdm(data_loader)):
            if isinstance(data, list) and len(data) == 2:
                target = data[0]
                output = data[1]
                batch_size = target.shape[0]
                # total_loss += loss.item() * batch_size
                pred = output[:, -T_pred:], target[:, -T_pred:]
                rec = output[:, :T_rec], target[:, :T_rec]

                assert T_rec + T_pred == target.shape[1]
                assert target.shape == output.shape
            else:
                if isinstance(data, list) and len(data) == 1:
                    data = data[0]
                # if config["data_loader"]["type"] == "MovingMNISTLoader":
                #     data = overlap_objects_from_batch(data,config['n_objects'])
                target = data  # the input sequence itself serves as the target
                data, target = data.to(device), target.to(device)

                output = model(data, epoch_iter=[-1], test=True)
                # computing loss, metrics on test set
                # loss, loss_particles = loss_fn(output, target,
                #                                epoch_iter=[-1],
                #                                case=config["data_loader"]["args"]["dataset_case"])
                batch_size = data.shape[0]
                # total_loss += loss.item() * batch_size

                pred = output["pred_roll"][:, -T_pred:] , target[:, -T_pred:] #* 0.85
                rec = output["rec_ori"][:, :T_rec] * 0.85, target[:, :T_rec]

                assert T_rec + T_pred == target.shape[1]

            if config['data_loader']['args']['dataset_case'] == 'circles_crop':
                rec_cr, pred_cr = [crop_top_left_keepdim(vid[0], 35) for vid in [rec, pred]]
                rec, pred = (rec_cr, target[:, :T_rec]), (pred_cr, target[:, -T_pred:])

            # Save image sample
            if i==0:
                if gt:
                    idx_gt = 1
                else:
                    idx_gt = 0
                # 11 fail to reconstruct.
                idx = 21
                # print(rec.shape, pred.shape)
                # print_u = output["u"].reshape(40, 2, -1, 4)[idx,:,-torch.cat(pred, dim=-2).shape[1]:]\
                #     .cpu()
                # print_u = print_u.abs()*255
                # print_im = torch.cat(pred, dim=-2).permute(0,2,3,1,4)[idx,0,:,:]
                print_im = pred[idx_gt].permute(0,2,3,1,4)[idx,0]
                np.save("/home/acomasma/ool-dynamics/dk/image_sample.npy", print_im.cpu().numpy())
                image = im.fromarray(print_im.reshape(print_im.shape[-3], -1).cpu().numpy()*255)
                image = image.convert('RGB')
                image.save("/home/acomasma/ool-dynamics/dk/image_sample.png")

                # u_plot_o1 = im.fromarray(plot_matrix(print_u[0]).permute(1,0).numpy()).convert('RGB')
                # u_plot_o1.save("/home/acomasma/ool-dynamics/dk/input_sample_o1.png")
                #
                # u_plot_o2 = im.fromarray(plot_matrix(print_u[1]).permute(1,0).numpy()).convert('RGB')
                # u_plot_o2.save("/home/acomasma/ool-dynamics/dk/input_sample_o2.png")
                # exit()
                image = im.fromarray(rec[idx_gt].permute(0,2,3,1,4)[idx,0].reshape(64, -1).cpu().numpy()*255)
                image = image.convert('RGB')
                image.save("/home/acomasma/ool-dynamics/dk/image_sample_rec.png")
                exit()

            all_pred.append(pred[0])
            all_rec.append(rec[0])

            for j, (out, tar) in enumerate([rec, pred]):
                for k, metric in enumerate(metric_fns):
                    # TODO: dataset case in metrics
                    total_metrics[j][k] += metric(out, tar) * batch_size

    n_samples = len(data_loader.sampler)
    print('n_samples', n_samples)
    # log = {'loss': total_loss / n_samples}
    log = {}

    print('Timesteps Rec and pred: ' , T_rec, T_pred)
    for j, name in enumerate(['rec', 'pred']):
        log.update({
            met.__name__: total_metrics[j][i].item() / n_samples for i, met in enumerate(metric_fns)
        })
        print(name)
        logger.info(log)
Example #11
def main(config: ConfigParser, local_master: bool, logger=None):
    train_batch_size = config['trainer']['train_batch_size']
    val_batch_size = config['trainer']['val_batch_size']

    train_num_workers = config['trainer']['train_num_workers']
    val_num_workers = config['trainer']['val_num_workers']

    # setup dataset and data_loader instances
    img_w = config['train_dataset']['args']['img_w']
    img_h = config['train_dataset']['args']['img_h']
    in_channels = config['model_arch']['args']['backbone_kwargs'][
        'in_channels']
    convert_to_gray = in_channels != 3
    train_dataset = config.init_obj('train_dataset',
                                    master_dataset,
                                    transform=ResizeWeight(
                                        (img_w, img_h),
                                        gray_format=convert_to_gray),
                                    convert_to_gray=convert_to_gray)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) \
        if config['distributed'] else None

    is_shuffle = not config['distributed']
    train_data_loader = config.init_obj(
        'train_loader',
        torch.utils.data.dataloader,
        dataset=train_dataset,
        sampler=train_sampler,
        batch_size=train_batch_size,
        collate_fn=DistCollateFn(training=True),
        num_workers=train_num_workers,
        shuffle=is_shuffle)

    val_dataset = config.init_obj('val_dataset',
                                  master_dataset,
                                  transform=ResizeWeight(
                                      (img_w, img_h),
                                      gray_format=convert_to_gray),
                                  convert_to_gray=convert_to_gray)
    val_sampler = DistValSampler(list(range(len(val_dataset))),
                                 batch_size=val_batch_size,
                                 distributed=config['distributed'])
    val_data_loader = config.init_obj('val_loader',
                                      torch.utils.data.dataloader,
                                      dataset=val_dataset,
                                      batch_sampler=val_sampler,
                                      batch_size=1,
                                      collate_fn=DistCollateFn(training=True),
                                      num_workers=val_num_workers)

    logger.info(
        f'Dataloader instances have finished. Train datasets: {len(train_dataset)} '
        f'Val datasets: {len(val_dataset)} Train_batch_size/gpu: {train_batch_size} '
        f'Val_batch_size/gpu: {val_batch_size}.') if local_master else None

    max_len_step = len(train_data_loader)
    if config['trainer']['max_len_step'] is not None:
        max_len_step = min(config['trainer']['max_len_step'], max_len_step)

    # build model architecture
    model = config.init_obj('model_arch', master_arch)
    logger.info(
        f'Model created, trainable parameters: {model.model_parameters()}.'
    ) if local_master else None

    # build optimizer, learning rate scheduler.
    optimizer = config.init_obj('optimizer', torch.optim, model.parameters())
    if config['lr_scheduler']['type'] is not None:
        lr_scheduler = config.init_obj('lr_scheduler',
                                       torch.optim.lr_scheduler, optimizer)
    else:
        lr_scheduler = None
    logger.info(
        'Optimizer and lr_scheduler created.') if local_master else None

    # log training related information
    logger.info(
        'Max_epochs: {} Log_step_interval: {} Validation_step_interval: {}.'.
        format(
            config['trainer']['epochs'],
            config['trainer']['log_step_interval'],
            config['trainer']['val_step_interval'])) if local_master else None

    logger.info('Training start...') if local_master else None

    trainer = Trainer(model,
                      optimizer,
                      config,
                      data_loader=train_data_loader,
                      valid_data_loader=val_data_loader,
                      lr_scheduler=lr_scheduler,
                      max_len_step=max_len_step)

    trainer.train()

    logger.info('Distributed training end...') if local_master else None
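One detail these distributed examples leave to the Trainer: when a DistributedSampler is used, the epoch loop should call set_epoch(epoch) so that shuffling differs across epochs. A minimal sketch of that loop (the Trainer internals are not shown in the source):

for epoch in range(1, config['trainer']['epochs'] + 1):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # reshuffle the distributed shards for this epoch
    for batch in train_data_loader:
        pass  # forward / backward / optimizer step happen inside Trainer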