Example #1
def main():

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)

    _, test_loader, _ = build_dataset(
        dataset=dataset, batch_size=args.batch_size, input_dir=args.input_dir
    )

    torch_device = torch.device("cuda")
    checkpointer = Checkpointer()

    model = checkpointer.restore_model_from_checkpoint(args.checkpoint_path)
    model = model.to(torch_device)
    model, _ = mixed_precision.initialize(model, None)

    test_stats = AverageMeterSet()
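    # run evaluation over the test loader, accumulating results into test_stats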
    test(model, test_loader, torch_device, test_stats)
    stat_str = test_stats.pretty_string(ignore=model.tasks)
    print(stat_str)
Example #2
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    # enable mixed-precision computation if desired
    amp = ""
    if args.amp:
        amp = "torch"
        if args.apex:
            print("Error: Cannot use both --amp and --apex.")
            exit()

    if args.apex:
        amp = "apex"
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path, training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf,
                      n_classes=num_classes,
                      n_rkhs=args.n_rkhs,
                      tclip=args.tclip,
                      n_depth=args.n_depth,
                      encoder_size=encoder_size,
                      use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args.output_dir, torch_device, amp)
Example #3
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args['output_dir']):
        os.mkdir(args['output_dir'])

    # enable mixed-precision computation if desired
    if args['amp']:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed(args['seed'])

    # get the dataset
    dataset = get_dataset(args['dataset'])
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args['output_dir'], args['run_name'])
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args['batch_size'],
                      input_dir=args['input_dir'],
                      labeled_only=args['classifiers'])

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args['output_dir'])
    if args['cpt_load_path']:
        model = checkpointer.restore_model_from_checkpoint(
            args['cpt_load_path'], training_classifier=args['classifiers'])
    else:
        # create new model with random parameters
        model = Model(ndf=args['ndf'],
                      n_classes=num_classes,
                      n_rkhs=args['n_rkhs'],
                      tclip=args['tclip'],
                      n_depth=args['n_depth'],
                      encoder_size=encoder_size,
                      use_bn=(args['use_bn'] == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    if args['classifiers']:
        task = train_classifiers
    elif args['decoder']:
        task = train_decoder
    else:
        task = train_self_supervised

    task(model, args['learning_rate'], dataset, train_loader, test_loader,
         stat_tracker, checkpointer, args['output_dir'], torch_device)
Example #4
def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    if load_ckpt is not None:
        checkpointer = Checkpointer(model,
                                    save_dir=cfg.OUTPUT_DIR,
                                    logger=logger,
                                    monitor_unit='episode')
        checkpointer.load_checkpoint(load_ckpt)

    val_metrics = model.metric_evaluator

    model.eval()
    num_images = 0
    meters = MetricLogger(delimiter="  ")

    logger.info('Start testing...')
    start_testing_time = time.time()
    end = time.time()
    for iteration, data in enumerate(dataloader):
        data_time = time.time() - end
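        # each batch carries a list of image tensors; concatenate them into one batch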
        inputs = torch.cat(data['images']).to(device)
        labels = data['labels'].to(device)
        logits = model(inputs)
        val_metrics.accumulated_update(logits, labels)
        num_images += logits.shape[0]

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (len(dataloader) - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 50 == 0 and iteration > 0:
            logger.info('eta: {}, iter: {}/{}'.format(eta_string, iteration,
                                                      len(dataloader)))

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1 acc: {:.4f}'.format(
        task, val_metrics.accumulated_topk_corrects['top1_acc']))

    total = time.time() - start_testing_time
    total_time_str = str(datetime.timedelta(seconds=total))

    logger.info("Total testing time: {}".format(total_time_str))
    return val_metrics
Example #5
def obtain_model(model_type):
    if model_type != 'robust':
        checkpoint_path = 'runs/amdim_cpt.pth'
        checkpointer = Checkpointer()
        print('Loading model')
        model = checkpointer.restore_model_from_checkpoint(checkpoint_path)
        torch_device = torch.device('cuda')
        model = model.to(torch_device)
    else:
        dataset = robustness.datasets.CIFAR()
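        # restore a robustly trained classifier via the robustness library's model utilities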
        model_kwargs = {
            'arch': 'resnet50',
            'dataset': dataset,
            'resume_path': '../robust_classif/robustness_applications/models/CIFAR.pt',
        }
        model, _ = model_utils.make_and_restore_model(**model_kwargs)
    model.eval()
    model = CommonModel(model, model_type)
    return model
Example #6
def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR, logger=logger)
    checkpointer.load_checkpoint(load_ckpt)
    val_metrics = TransferNetMetrics(cfg)

    model.eval()
    num_images = 0
    logger.info('Start testing...')
    for iteration, data in enumerate(dataloader):
        inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
        cls_scores = model(inputs, targets)
        val_metrics.accumulated_update(cls_scores, targets)
        num_images += len(targets)

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1/5 acc: {:.4f}/{:.4f}, mean class acc: {:.4f}'.format(
        task,
        val_metrics.accumulated_topk_corrects['top1_acc'],
        val_metrics.accumulated_topk_corrects['top5_acc'],
        val_metrics.mean_class_accuracy))
    return val_metrics
Example #7
def evaluateVideos(seqs, tag):
    dh = DeviceHelper()
    print('Using ' + str(dh.device))
    cp = Checkpointer(params.checkpoint_path, dh.device)
    model, _ = Driver.CreateModelAndOpt(params, dh, cp)

    for seq_info in seqs:
        seq, _, _ = seq_info
        print('Processing ' + seq)
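        # build a dataset for this sequence, move it to the device, and write predictions to an .npy file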
        s = SingleVideo(params.image_dir, params.pose_dir, seq_info)
        ds = DeviceDataset(s)
        output = os.path.join(params.pred_dir, tag + '_' + seq + '.npy')
        Driver.evalOnVideo(model, ds, output)

    seq_list = [x for x, _, _ in seqs]

    Plotting.plot(params.pose_dir, params.pred_dir, tag + '_',
                  params.result_dir, seq_list)
Example #8
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = self._load_image(item['image'])
        # convert 3D to 2D tensor to store in kaldi-format
        uttid = item["uttid"]
        return {"uttid": uttid, "image": img}


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    # Restore model from the checkpoint
    ckpt = Checkpointer()
    ckpt.restore_model_from_checkpoint(
        cpt_path="amdim_ndf256_rkhs2048_rd10.pth")
    ckpt.model.to('cuda')
    img_tmp_ark = os.path.splitext(args.img_as_feats_scp)[0] + '.tmp.ark'
    ds = ImageDataset(args.places_json)
    with kio.open_or_fd(img_tmp_ark, 'wb') as f:
        for i in tqdm(range(len(ds))):
            item = ds[i]
            feats = item["image"]
            batch = torch.zeros(2, 3, 128, 128)
            batch[0] = feats
            batch = batch.to('cuda')
            res_dict = ckpt.model(x1=batch, x2=batch, class_only=True)
            global_feats = res_dict["rkhs_glb"][:1]
            k = item["uttid"]
Example #9
from neat import config, population, chromosome, genome, visualize
from neat.nn import nn_pure as nn
from checkpoint import Checkpointer
from player_2048 import *  # provides play_with_AI

import sys

pop = Checkpointer.restore_checkpoint(sys.argv[1])

winner = pop.stats[0][-1]
print('Number of evaluations: %d' % winner.id)

# Visualize the winner network (requires PyDot)
#visualize.draw_net(winner) # best chromosome

# Plots the evolution of the best/average fitness (requires Biggles)
#visualize.plot_stats(pop.stats)
# Visualizes speciation
#visualize.plot_species(pop.species_log)

print(pop.species_log)

# Let's check if it's really solved the problem
print('\nBest network output:')
brain = nn.create_ffphenotype(winner)
score = play_with_AI(brain)
print('best score: ' + str(int(score)))
Example #10
from driver import *
from parameters import *

from device import DeviceHelper
from checkpoint import Checkpointer

from dataset import VideoDataset

dh = DeviceHelper()
cp = Checkpointer(params.checkpoint_path, dh.device)

# load model and optimizer
model, opt = Driver.CreateModelAndOpt(params, dh, cp)

# run training
history = Driver.fit(params.epochs, model, opt, params, cp)

# save history and model to checkpoint
with open(params.log_file, 'a') as of:
    of.write('*' * 50 + '\n')
    of.write('Loss sequence\n')
    of.write(str(history) + '\n')

epoch, loss = zip(*history)

cp.CreateCheckpoint(model, opt, params.epochs, min(loss),
                    params.checkpoint_tag, params.epochs)
Example #11
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    # move the model to GPU or CPU, depending on `device`
    model.to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # learning-rate schedule: decay to one tenth at the halfway point
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter // 2], gamma=0.1)

    # if a saved checkpoint exists, load the latest one and resume training
    checkpointer = Checkpointer(
        model, optimizer, scheduler, pths_path
    )
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)

    start_epoch = arguments['iteration']  # epoch to resume from

    logger.info('start_epoch is :{}'.format(start_epoch))

    for epoch in range(start_epoch, epoch_iter):
        iteration = epoch + 1
        arguments['iteration'] = iteration
        model.train()
        epoch_loss = 0  # accumulated loss for this epoch
        epoch_time = time.time()  # start time of this epoch
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
Example #12
    num_workers = model_conf['train_data_loader_kwargs']['num_workers']
    lr = model_conf['trainer_kwargs']['opt_kwargs']['lr']

    #update run id in config file
    model_conf['run_id'] = run_id
    with open('config.json', 'w') as f:
        json.dump(model_conf, f)

    #Construct DataLoader and checkpointer
    train_loader = Image_Loader(img_path,
                                batch_size=batch_size,
                                shuffle=True,
                                drop_last=True,
                                num_workers=num_workers,
                                input_shape=input_shape,
                                stage='train')
    checkpointer = Checkpointer(run=run)

    # Load checkpoint if given, otherwise construct a new model
    encoder, mi_estimator = checkpointer.restore_model_from_checkpoint()

    # Compute on multiple GPUs, if there are more than one given
    if torch.cuda.device_count() > 1:
        print("Let's use %d GPUs" % torch.cuda.device_count())
        encoder = torch.nn.DataParallel(encoder)
        mi_estimator = torch.nn.DataParallel(mi_estimator)
    encoder.to(device)
    mi_estimator.to(device)

    enc_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    mi_optim = torch.optim.Adam(mi_estimator.parameters(), lr=lr)
    try:
Example #13
def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                    cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model,
                                optimizer,
                                lr_scheduler,
                                cfg.OUTPUT_DIR,
                                logger,
                                monitor_unit='episode')

    training_args = {}
    # training_args['iteration'] = 1
    training_args['episode'] = 1
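    # optionally warm-start from a given checkpoint, then resume from the latest saved one if present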
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    model.train()
    break_while = False
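    # loop over the dataloader repeatedly until the episode budget is exhausted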
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end
            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)
            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE -
                                                        episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 /
                                1024.0) if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                lr_scheduler.step()
                logger.info("lr decayed to {:.4f}".format(
                    optimizer.param_groups[-1]["lr"]))

            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)
                break

            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)

            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time /
        (episode - start_episode if episode > start_episode else 1)))
Example #14
def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)

    training_args = {}
    # training_args['iteration'] = 1
    training_args['epoch'] = 1
    training_args['val_best'] = 0.
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    # start_iter = training_args['iteration']
    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            # training_args['iteration'] = iteration
            # logger.info('inner_iter: {}, label: {}'.format(inner_iter, data['label_articleType'], len(data['label_articleType'])))
            data_time = time.time() - end
            inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (len(train_dataloader) * cfg.TRAIN.MAX_EPOCH -
                                                        (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 /
                                1024.0) if torch.cuda.is_available() else 0.,
                    )
                )

        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
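            # run validation; checkpoint and reset patience whenever mean class accuracy improves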
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(epoch, checkpointer.current_val_best),
                                  **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1

            logger.info('current patience: {}/{}'.format(checkpointer.patience, cfg.TRAIN.PATIENCE))

        if epoch == cfg.TRAIN.MAX_EPOCH or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0 or checkpointer.patience == cfg.TRAIN.PATIENCE:
            checkpointer.save("model_{:04d}".format(epoch), **training_args)

        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience triggered. Early terminate training')
            break

        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            lr_scheduler.step()
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str, total_training_time / (epoch - start_epoch if epoch > start_epoch else 1)
        )
    )
def main():
    # argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file',
                        help='path of config file',
                        default=None,
                        type=str)
    parser.add_argument('--clean_run',
                        help='run from scratch',
                        default=False,
                        type=bool)
    parser.add_argument('opts',
                        help='modify arguments',
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # config setup
    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None: cfg.merge_from_list(args.opts)

    cfg.freeze()
    if args.clean_run:
        if os.path.exists(f'../experiments/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/{cfg.SYSTEM.EXP_NAME}')
        if os.path.exists(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}'):
            shutil.rmtree(f'../experiments/runs/{cfg.SYSTEM.EXP_NAME}')
            # Note!: Sleeping to make tensorboard delete its cache.
            time.sleep(5)

    search = defaultdict()
    search['lr'], search['momentum'], search['factor'], search['step_size'] = [True] * 4
    set_seeds(cfg)
    logdir, chk_dir = save_config(cfg.SAVE_ROOT, cfg)
    writer = SummaryWriter(log_dir=logdir)
    # setup logger
    logger_dir = Path(chk_dir).parent
    logger = setup_logger(cfg.SYSTEM.EXP_NAME, save_dir=logger_dir)
    # Model
    prediction_model = BaseModule(cfg)
    noise_model = NoiseModule(cfg)
    model = [prediction_model, noise_model]
    device = cfg.SYSTEM.DEVICE if torch.cuda.is_available() else 'cpu'
    # load the data
    train_loader = get_loader(cfg, 'train')
    val_loader = get_loader(cfg, 'val')
    prediction_model, noise_model = model
    prediction_model.to(device)
    lr = cfg.SOLVER.LR
    momentum = cfg.SOLVER.MOMENTUM
    weight_decay = cfg.SOLVER.WEIGHT_DECAY
    betas = cfg.SOLVER.BETAS
    step_size = cfg.SOLVER.STEP_SIZE
    decay_factor = cfg.SOLVER.FACTOR

    # Optimizer
    if cfg.SOLVER.OPTIMIZER == 'Adam':
        optimizer = optim.Adam(prediction_model.parameters(),
                               lr=lr,
                               weight_decay=weight_decay,
                               betas=betas)
    elif cfg.SOLVER.OPTIMIZER == 'SGD':
        optimizer = optim.SGD(prediction_model.parameters(),
                              lr=lr,
                              weight_decay=weight_decay,
                              momentum=momentum)
    if cfg.SOLVER.SCHEDULER == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=step_size,
                                              gamma=decay_factor)
    elif cfg.SOLVER.SCHEDULER == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.SOLVER.FACTOR,
            min_lr=cfg.SOLVER.MIN_LR,
            patience=cfg.SOLVER.PAITENCE,
            cooldown=cfg.SOLVER.COOLDOWN,
            threshold=cfg.SOLVER.THRESHOLD,
            eps=1e-24)
    # checkpointer
    chkpt = Checkpointer(prediction_model,
                         optimizer,
                         scheduler=scheduler,
                         save_dir=chk_dir,
                         logger=logger,
                         save_to_disk=True)
    offset = 0
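    # resume from the epoch stored in the checkpoint, if one was loaded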
    checkpoint_data = chkpt.load()
    if checkpoint_data:
        offset = checkpoint_data.pop('epoch')
    loader = [train_loader, val_loader]
    print(f'Same optimizer, {scheduler.optimizer == optimizer}')
    print(cfg)
    model = [prediction_model, noise_model]
    train(cfg, model, optimizer, scheduler, loader, chkpt, writer, offset)
    test_loader = get_loader(cfg, 'test')
    test(cfg, prediction_model, test_loader, writer, logger)