# Ejemplo n.º 1
 def save_checkpoints(engine):
     """Step the WER-driven LR scheduler, write a checkpoint, and record
     this epoch's WER/CER in ``best_meter``.

     NOTE(review): reads metrics from *engine* but the epoch from the
     module-level ``trainer`` — presumably engine is the evaluator and
     trainer the training engine; confirm against the attach site.
     """
     epoch = trainer.state.epoch
     wer = engine.state.metrics['wer']
     cer = engine.state.metrics['cer']
     # ReduceLROnPlateau-style scheduler keyed on word error rate.
     scheduler.step(wer)
     save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
     best_meter.update(wer, cer, epoch)
# Ejemplo n.º 2
def main(args):
    """Train NeRFmm on a COLMAP scene.

    Jointly optimises three modules on the training images: a NeRF model
    (``OfficialNerf``), the camera focal lengths (``LearnFocal``) and one
    camera-to-world pose per image (``LearnPose``), each with its own Adam
    optimiser and MultiStepLR scheduler.  Periodically evaluates the learned
    trajectory against COLMAP, logs the estimated focals, renders eval
    images from an identity pose, and saves the latest checkpoints under
    ``./logs/nerfmm/<scene_name>/<detail_name>``.
    """
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/nerfmm', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    # Snapshot the sources used for this run so results stay reproducible.
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/nerfmm/train.py', experiment_dir)

    if args.store_pose_history:
        pose_history_dir = Path(os.path.join(experiment_dir, 'pose_history'))
        pose_history_dir.mkdir(parents=True, exist_ok=True)

    '''LOG'''
    # INFO and above goes to log.txt; only WARNING and above reaches stderr.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))

    '''Data Loading'''
    scene_train = DataLoaderWithCOLMAP(base_dir=args.base_dir,
                                       scene_name=args.scene_name,
                                       data_type='train',
                                       res_ratio=args.resize_ratio,
                                       num_img_to_load=args.train_img_num,
                                       skip=args.train_skip,
                                       use_ndc=args.use_ndc)

    # The COLMAP eval poses are not in the same camera space that we learned so we can only check NVS
    # with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels +
                       int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels +
                           int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter; initialise from COLMAP's focal only when focal
    # refinement is scheduled to start at some epoch (> -1).
    if args.start_refine_focal_epoch > -1:
        focal_net = LearnFocal(scene_train.H,
                               scene_train.W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order,
                               init_focal=scene_train.focal)
    else:
        focal_net = LearnFocal(scene_train.H,
                               scene_train.W,
                               args.learn_focal,
                               args.fx_only,
                               order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image; initialise from COLMAP's c2ws only when pose
    # refinement is scheduled to start at some epoch (> -1).
    if args.start_refine_pose_epoch > -1:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R,
                                   args.learn_t, scene_train.c2ws)
    else:
        pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R,
                                   args.learn_t, None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(
            device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(),
                                       lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(),
                                      lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_nerf,
        milestones=args.nerf_milestones,
        gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal,
        milestones=args.focal_milestones,
        gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_pose,
        milestones=args.pose_milestones,
        gamma=args.pose_lr_gamma)

    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf,
                                             optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net,
                                             my_devices, args, rgb_act_fn,
                                             epoch_i)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        # get_last_lr() reports the lr actually in use; get_lr() is deprecated
        # when called outside step() and over-applies gamma at milestone epochs.
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))

        # Store learned poses densely early on, then progressively sparser.
        pose_history_milestone = list(range(0, 100, 5)) + list(
            range(100, 1000, 100)) + list(range(1000, 10000, 1000))
        if epoch_i in pose_history_milestone:
            with torch.no_grad():
                if args.store_pose_history:
                    store_current_pose(pose_param_net, pose_history_dir,
                                       epoch_i)

        if epoch_i % args.eval_cam_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                # Compare the learned trajectory against the COLMAP poses.
                eval_stats_tran, eval_stats_rot, eval_stats_scale = eval_one_epoch_traj(
                    scene_train, pose_param_net)
                writer.add_scalar('eval/traj/translation',
                                  eval_stats_tran['mean'], epoch_i)
                writer.add_scalar('eval/traj/rotation', eval_stats_rot['mean'],
                                  epoch_i)
                writer.add_scalar('eval/traj/scale', eval_stats_scale['mean'],
                                  epoch_i)

                logger.info(
                    '{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'
                    .format(epoch_i, eval_stats_tran['mean'],
                            eval_stats_rot['mean'], eval_stats_scale['mean']))
                tqdm.write(
                    '{0:6d} ep Traj Err: translation: {1:.6f}, rotation: {2:.2f} deg, scale: {3:.2f}'
                    .format(epoch_i, eval_stats_tran['mean'],
                            eval_stats_rot['mean'], eval_stats_scale['mean']))

                fxfy = focal_net(0)
                tqdm.write(
                    'Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.
                    format(fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                logger.info(
                    'Est fx: {0:.2f}, fy {1:.2f}, COLMAP focal: {2:.2f}'.
                    format(fxfy[0].item(), fxfy[1].item(), scene_train.focal))
                if torch.is_tensor(fxfy):
                    L1_focal = torch.abs(fxfy -
                                         scene_train.focal).mean().item()
                else:
                    L1_focal = np.abs(fxfy - scene_train.focal).mean()
                writer.add_scalar('eval/L1_focal', L1_focal, epoch_i)

        if epoch_i % args.eval_img_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch_img(eval_c2ws, scene_train, model, focal_net,
                                   pose_param_net, my_devices, args, epoch_i,
                                   writer, rgb_act_fn)

                # save the latest model.
                save_checkpoint(epoch_i,
                                model,
                                optimizer_nerf,
                                experiment_dir,
                                ckpt_name='latest_nerf')
                save_checkpoint(epoch_i,
                                focal_net,
                                optimizer_focal,
                                experiment_dir,
                                ckpt_name='latest_focal')
                save_checkpoint(epoch_i,
                                pose_param_net,
                                optimizer_pose,
                                experiment_dir,
                                ckpt_name='latest_pose')
    return
# Ejemplo n.º 3
def train(model, optimizer, train_loader, valid_loader, num_epochs, eval_every,
          file_path, best_valid_loss=float("Inf")):
    """Train *model*, validating every *eval_every* optimisation steps.

    The model is expected to return ``(loss, logits)`` when called with
    ``(text, labels)``.  Whenever the average validation loss improves on
    *best_valid_loss*, a checkpoint is written to ``file_path``.  Loss
    histories are accumulated in local lists alongside the global step.
    """
    # running accumulators and history buffers
    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    model.train()
    print("init train")
    for epoch in range(num_epochs):
        for labels, text in train_loader:
            # cast to int64 and move to the (module-level) device
            labels = labels.type(torch.LongTensor).to(device)
            text = text.type(torch.LongTensor).to(device)
            loss, _ = model(text, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every != 0:
                continue

            # --- validation pass ---
            model.eval()
            with torch.no_grad():
                for val_labels, val_text in valid_loader:
                    val_labels = val_labels.type(torch.LongTensor).to(device)
                    val_text = val_text.type(torch.LongTensor).to(device)
                    loss, _ = model(val_text, val_labels)
                    valid_running_loss += loss.item()

            # average train loss over the window, valid loss over the loader
            average_train_loss = running_loss / eval_every
            average_valid_loss = valid_running_loss / len(valid_loader)
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(average_valid_loss)
            global_steps_list.append(global_step)

            # reset the running windows and resume training mode
            running_loss = 0.0
            valid_running_loss = 0.0
            model.train()

            print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, global_step, num_epochs * len(train_loader),
                          average_train_loss, average_valid_loss))

            # checkpoint on validation-loss improvement
            if best_valid_loss > average_valid_loss:
                best_valid_loss = average_valid_loss
                save_checkpoint(file_path + '/' + 'model2.pt', model, best_valid_loss)

    print('Finished Training!')
# Ejemplo n.º 4
def main(args):
    """Train NeRFmm in the 'any_folder' setting (no COLMAP poses available).

    Jointly optimises a NeRF model (``OfficialNerf``), the camera focal
    lengths (``LearnFocal``) and one pose per training image (``LearnPose``),
    each with its own Adam optimiser and MultiStepLR scheduler.  With no
    ground-truth eval poses, novel-view evaluation renders from a 4x4
    identity pose.  Latest checkpoints are saved every
    ``args.eval_interval`` epochs under ``./logs/any_folder/<scene_name>``.
    """
    my_devices = torch.device('cuda:' + str(args.gpu_id))

    '''Create Folders'''
    exp_root_dir = Path(os.path.join('./logs/any_folder', args.scene_name))
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    # Snapshot the sources used for this run so results stay reproducible.
    shutil.copy('./models/nerf_models.py', experiment_dir)
    shutil.copy('./models/intrinsics.py', experiment_dir)
    shutil.copy('./models/poses.py', experiment_dir)
    shutil.copy('./tasks/any_folder/train.py', experiment_dir)

    '''LOG'''
    # INFO and above goes to log.txt; only WARNING and above reaches stderr.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=str(experiment_dir))

    '''Data Loading'''
    scene_train = DataLoaderAnyFolder(base_dir=args.base_dir,
                                      scene_name=args.scene_name,
                                      res_ratio=args.resize_ratio,
                                      num_img_to_load=args.train_img_num,
                                      start=args.train_start,
                                      end=args.train_end,
                                      skip=args.train_skip,
                                      load_sorted=args.train_load_sorted)

    print('Train with {0:6d} images.'.format(scene_train.imgs.shape[0]))

    # We have no eval pose in this any_folder task. Eval with a 4x4 identity pose.
    eval_c2ws = torch.eye(4).unsqueeze(0).float()  # (1, 4, 4)

    '''Model Loading'''
    pos_enc_in_dims = (2 * args.pos_enc_levels +
                       int(args.pos_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    if args.use_dir_enc:
        dir_enc_in_dims = (2 * args.dir_enc_levels +
                           int(args.dir_enc_inc_in)) * 3  # (2L + 0 or 1) * 3
    else:
        dir_enc_in_dims = 0

    model = OfficialNerf(pos_enc_in_dims, dir_enc_in_dims, args.hidden_dims)
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device=my_devices)
    else:
        model = model.to(device=my_devices)

    # learn focal parameter
    focal_net = LearnFocal(scene_train.H,
                           scene_train.W,
                           args.learn_focal,
                           args.fx_only,
                           order=args.focal_order)
    if args.multi_gpu:
        focal_net = torch.nn.DataParallel(focal_net).to(device=my_devices)
    else:
        focal_net = focal_net.to(device=my_devices)

    # learn pose for each image (no initial poses available in this task)
    pose_param_net = LearnPose(scene_train.N_imgs, args.learn_R, args.learn_t,
                               None)
    if args.multi_gpu:
        pose_param_net = torch.nn.DataParallel(pose_param_net).to(
            device=my_devices)
    else:
        pose_param_net = pose_param_net.to(device=my_devices)

    '''Set Optimiser'''
    optimizer_nerf = torch.optim.Adam(model.parameters(), lr=args.nerf_lr)
    optimizer_focal = torch.optim.Adam(focal_net.parameters(),
                                       lr=args.focal_lr)
    optimizer_pose = torch.optim.Adam(pose_param_net.parameters(),
                                      lr=args.pose_lr)

    scheduler_nerf = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_nerf,
        milestones=args.nerf_milestones,
        gamma=args.nerf_lr_gamma)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal,
        milestones=args.focal_milestones,
        gamma=args.focal_lr_gamma)
    scheduler_pose = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_pose,
        milestones=args.pose_milestones,
        gamma=args.pose_lr_gamma)

    '''Training'''
    for epoch_i in tqdm(range(args.epoch), desc='epochs'):
        rgb_act_fn = torch.sigmoid
        train_epoch_losses = train_one_epoch(scene_train, optimizer_nerf,
                                             optimizer_focal, optimizer_pose,
                                             model, focal_net, pose_param_net,
                                             my_devices, args, rgb_act_fn)
        train_L2_loss = train_epoch_losses['L2']
        scheduler_nerf.step()
        scheduler_focal.step()
        scheduler_pose.step()

        train_psnr = mse2psnr(train_L2_loss)
        writer.add_scalar('train/mse', train_L2_loss, epoch_i)
        writer.add_scalar('train/psnr', train_psnr, epoch_i)
        # get_last_lr() reports the lr actually in use; get_lr() is deprecated
        # when called outside step() and over-applies gamma at milestone epochs.
        writer.add_scalar('train/lr', scheduler_nerf.get_last_lr()[0], epoch_i)
        logger.info('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))
        tqdm.write('{0:6d} ep: Train: L2 loss: {1:.4f}, PSNR: {2:.3f}'.format(
            epoch_i, train_L2_loss, train_psnr))

        if epoch_i % args.eval_interval == 0 and epoch_i > 0:
            with torch.no_grad():
                eval_one_epoch(eval_c2ws, scene_train, model, focal_net,
                               pose_param_net, my_devices, args, epoch_i,
                               writer, rgb_act_fn)

                fxfy = focal_net(0)
                tqdm.write('Est fx: {0:.2f}, fy {1:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item()))
                logger.info('Est fx: {0:.2f}, fy {1:.2f}'.format(
                    fxfy[0].item(), fxfy[1].item()))

                # save the latest model
                save_checkpoint(epoch_i,
                                model,
                                optimizer_nerf,
                                experiment_dir,
                                ckpt_name='latest_nerf')
                save_checkpoint(epoch_i,
                                focal_net,
                                optimizer_focal,
                                experiment_dir,
                                ckpt_name='latest_focal')
                save_checkpoint(epoch_i,
                                pose_param_net,
                                optimizer_pose,
                                experiment_dir,
                                ckpt_name='latest_pose')
    return
# Ejemplo n.º 5
def main(args):
    """Train NINormalNet for (un-oriented) point-cloud normal estimation.

    Creates a run directory with source snapshots, file logging and a
    TensorBoard writer, loads the train/eval KNN-patch datasets, then runs
    the epoch loop: per-epoch pgp metrics and the model's ``temp`` value are
    logged, and a checkpoint is saved whenever eval pgp010 matches or beats
    the best seen so far.
    """
    '''Create the log dir for this run'''
    exp_root_dir = Path('./logs/')
    exp_root_dir.mkdir(parents=True, exist_ok=True)
    experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args)))
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # copy train and model file to ensure reproducibility.
    shutil.copy('./model.py', experiment_dir)
    shutil.copy('./train.py', experiment_dir)

    '''Logger'''
    logger = logging.getLogger("NINormalNetTrain")
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt'))
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.info(args)

    '''Summary Writer'''
    writer = SummaryWriter(log_dir=experiment_dir)

    '''Data Loading'''
    logger.info('Load dataset ...')
    train_dataset = PcpKnnPatchesDataset(datafolder=args.datafolder,
                                         dataset_type='train',
                                         dataset_name=args.train_dataset_name,
                                         fastdebug=args.fastdebug,
                                         noise_level=args.train_noise_level)

    eval_dataset = PcpKnnPatchesDataset(datafolder=args.datafolder,
                                        dataset_type='eval',
                                        dataset_name=args.eval_dataset_name,
                                        fastdebug=args.fastdebug,
                                        noise_level=args.eval_noise_level)

    train_dataloader = DataLoader(train_dataset, batch_size=args.batchsize_train,
                                  shuffle=True, num_workers=10, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batchsize_eval,
                                 shuffle=False, num_workers=10, pin_memory=True)

    '''Model Loading'''
    model = NINormalNet()
    if args.multi_gpu:
        model = torch.nn.DataParallel(model).to(device='cuda:' + str(args.gpu_id))
    else:
        model = model.to(device='cuda:'+str(args.gpu_id))

    '''Set Optimiser'''
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.L2_reg)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_gamma)

    '''Training'''
    best_temp = 0.0
    best_eval_pgp010 = 0.0
    sys.stdout.flush()
    logger.info('Start training...')
    for epoch in tqdm(range(args.epoch), desc='epochs'):
        logger.info('Epoch (%d/%s):', epoch + 1, args.epoch)

        train_epoch_metric = train_one_epoch(train_dataloader, optimizer, model, args)
        with torch.no_grad():
            eval_epoch_metric = eval_one_epoch(eval_dataloader, model, args)
        scheduler.step()

        tqdm.write('pgp010: train: {0:.4f}, eval: {1:.4f}'.format(train_epoch_metric['pgp010'], eval_epoch_metric['pgp010']))
        logger.info('pgp010: train: {0:.4f}, eval: {1:.4f}'.format(train_epoch_metric['pgp010'], eval_epoch_metric['pgp010']))

        writer.add_scalar('train/loss', train_epoch_metric['loss'], epoch)
        writer.add_scalar('train/pgp003', train_epoch_metric['pgp003'], epoch)
        writer.add_scalar('train/pgp005', train_epoch_metric['pgp005'], epoch)
        writer.add_scalar('train/pgp010', train_epoch_metric['pgp010'], epoch)
        writer.add_scalar('train/pgp030', train_epoch_metric['pgp030'], epoch)
        writer.add_scalar('train/pgp060', train_epoch_metric['pgp060'], epoch)
        writer.add_scalar('train/pgp080', train_epoch_metric['pgp080'], epoch)
        writer.add_scalar('train/pgp090', train_epoch_metric['pgp090'], epoch)  # sanity check, should be always 1.0 in un-oriented normal estimation
        # get_last_lr() reports the lr actually in use; get_lr() is deprecated
        # when called outside step() and over-applies gamma at milestone epochs.
        writer.add_scalar('train/lr', scheduler.get_last_lr()[0], epoch)
        # DataParallel wraps the model, so ``temp`` lives on ``model.module``.
        writer.add_scalar('temp', model.module.temp if args.multi_gpu else model.temp, epoch)

        writer.add_scalar('eval/loss', eval_epoch_metric['loss'], epoch)
        writer.add_scalar('eval/pgp003', eval_epoch_metric['pgp003'], epoch)
        writer.add_scalar('eval/pgp005', eval_epoch_metric['pgp005'], epoch)
        writer.add_scalar('eval/pgp010', eval_epoch_metric['pgp010'], epoch)
        writer.add_scalar('eval/pgp030', eval_epoch_metric['pgp030'], epoch)
        writer.add_scalar('eval/pgp060', eval_epoch_metric['pgp060'], epoch)
        writer.add_scalar('eval/pgp080', eval_epoch_metric['pgp080'], epoch)
        writer.add_scalar('eval/pgp090', eval_epoch_metric['pgp090'], epoch)  # sanity check, should be always 1.0 in un-oriented normal estimation

        # ``>=`` so ties still refresh the checkpoint with the newer weights.
        if eval_epoch_metric['pgp010'] >= best_eval_pgp010:
            best_eval_pgp010 = eval_epoch_metric['pgp010']
            best_temp = model.module.temp if args.multi_gpu else model.temp

            logger.info('Saving model with the best pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp))
            tqdm.write('Saving model with the best pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp))
            save_checkpoint(epoch, train_epoch_metric['pgp010'], eval_epoch_metric['pgp010'], model,
                            optimizer, str(experiment_dir), args.model_name)

    print('Best eval pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp))
    logger.info('Best eval pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp))

    print('Final temp: {0:.4f}'.format(model.module.temp if args.multi_gpu else model.temp))
    logger.info('Final temp: {0:.4f}'.format(model.module.temp if args.multi_gpu else model.temp))

    return
# Ejemplo n.º 6
def train_seg(args):
    """Train a DenseSeg semantic-segmentation model.

    Builds the network (optionally loading pretrained weights from
    ``args.pretrained``), NLL criterion with ignore label 255, augmented
    train / plain val data loaders, and an SGD optimiser.  Supports resuming
    from ``args.resume`` and evaluation-only mode via ``args.evaluate``.
    Each epoch saves ``checkpoint_latest``; every 10th epoch is also kept as
    a numbered history snapshot.
    """
    single_model = DenseSeg(model_name=args.arch,
                            classes=args.num_class,
                            transition_layer=args.transition_layer,
                            conv_num_features=args.conv_num_features,
                            out_channels_num=args.out_channels_num,
                            ppl_out_channels_num=args.ppl_out_channels_num,
                            dilation=args.dilation,
                            pretrained=True)
    model = torch.nn.DataParallel(single_model).cuda()
    if args.pretrained:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
    # nn.NLLLoss2d is deprecated and was removed from modern PyTorch;
    # nn.NLLLoss accepts 2D spatial targets and behaves identically.
    criterion = nn.NLLLoss(ignore_index=255)
    criterion.cuda()

    # Data loading code
    info = json.load(open(os.path.join(args.data_dir, 'info.json'), 'r'))

    # data augmentation
    t = []
    normalize = data_transform_utils.Normalize(mean=info['mean'],
                                               std=info['std'])
    if args.random_rotate > 0:
        t.append(data_transform_utils.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(data_transform_utils.RandomScale(args.random_scale))
    t.extend([
        data_transform_utils.RandomCrop(args.crop_size),
        data_transform_utils.RandomHorizontalFlip(),
        data_transform_utils.ToTensor(), normalize
    ])

    train_loader = torch.utils.data.DataLoader(SegList(
        args.data_dir, 'train', data_transform_utils.Compose(t)),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(SegList(
        args.data_dir, 'val',
        data_transform_utils.Compose(
            [data_transform_utils.ToTensor(), normalize])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             drop_last=True)

    # define the optimizer (the criterion was created above)
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(args,
              train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Model saving
        checkpoint_path = './model_save_dir/checkpoint_latest' + args.model_save_suffix
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)

        # keep a numbered snapshot every 10 epochs
        if (epoch + 1) % 10 == 0:
            history_path = './model_save_dir/checkpoint_{:03d}_{:s}'.format(
                epoch + 1, args.model_save_suffix)
            shutil.copyfile(checkpoint_path, history_path)