Example #1
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

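    # load the best model saved during training; map_location keeps tensors on CPU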
    checkpoint = torch.load(os.path.join(log_dir, 'best_model.pt'),
                            map_location=lambda storage, loc: storage)
    model_wrapper.load_state_dict(checkpoint['model'])
    # restore the EMA copy if the checkpoint carries one; otherwise pass None
    ema = checkpoint.get('ema', None)
    optimizer = get_optimizer(model_wrapper)

    mprint('Start testing.')
    test_meters = get_meters('test')
    with torch.no_grad():
        run_one_epoch(-1,
                      test_loader,
                      model_wrapper,
                      criterion,
                      optimizer,
                      test_meters,
                      phase='test',
                      ema=ema)
Example #2
def validate(config, testloader, model, writer_dict):
    model.eval()
    ave_loss = AverageMeter()
    nums = config.MODEL.NUM_OUTPUTS
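    # one (num_classes x num_classes) confusion matrix per output head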
    confusion_matrix = np.zeros(
        (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES, nums))
    with torch.no_grad():
        for idx, batch in enumerate(testloader):
            image, label, _, _ = batch
            size = label.size()
            image = image.cuda()
            label = label.long().cuda()

            losses, pred = model(image, label)
            if not isinstance(pred, (list, tuple)):
                pred = [pred]
            for i, x in enumerate(pred):
                x = F.interpolate(
                    input=x, size=size[-2:],
                    mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
                )

                confusion_matrix[..., i] += get_confusion_matrix(
                    label,
                    x,
                    size,
                    config.DATASET.NUM_CLASSES,
                    config.TRAIN.IGNORE_LABEL
                )

            if idx % 10 == 0:
                print(idx)

            loss = losses.mean()
            if dist.is_distributed():
                reduced_loss = reduce_tensor(loss)
            else:
                reduced_loss = loss
            ave_loss.update(reduced_loss.item())

    if dist.is_distributed():
        confusion_matrix = torch.from_numpy(confusion_matrix).cuda()
        reduced_confusion_matrix = reduce_tensor(confusion_matrix)
        confusion_matrix = reduced_confusion_matrix.cpu().numpy()

    for i in range(nums):
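        # per-class IoU = TP / (ground-truth total + prediction total - TP)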
        pos = confusion_matrix[..., i].sum(1)
        res = confusion_matrix[..., i].sum(0)
        tp = np.diag(confusion_matrix[..., i])
        IoU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IoU = IoU_array.mean()
        if dist.get_rank() <= 0:
            logging.info('{} {} {}'.format(i, IoU_array, mean_IoU))

    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']
    writer.add_scalar('valid_loss', ave_loss.average(), global_steps)
    writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
    writer_dict['valid_global_steps'] = global_steps + 1
    return ave_loss.average(), mean_IoU, IoU_array
Example #3
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict):
    # Training
    model.train()
    scaler = GradScaler()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    ave_acc = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.long().cuda()
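        # forward pass under autocast so eligible ops run in mixed precision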
        with autocast():
            losses, _, acc = model(images, labels)
        loss = losses.mean()
        acc = acc.mean()

        if dist.is_distributed():
            reduced_loss = reduce_tensor(loss)
        else:
            reduced_loss = loss

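        # AMP: scale the loss before backward; scaler.step() unscales the gradients
        # and skips the optimizer step if inf/NaN gradients are detected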
        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(reduced_loss.item())
        ave_acc.update(acc.item())

        lr = adjust_learning_rate(optimizer, base_lr, num_iters,
                                  i_iter + cur_iters)

        if i_iter % config.PRINT_FREQ == 0 and dist.get_rank() == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {}, Loss: {:.6f}, Acc:{:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), [x['lr'] for x in optimizer.param_groups], ave_loss.average(),
                      ave_acc.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1
Example #4
def check_dist_init(config, logger):
    # check distributed initialization
    if config.distributed.enable:
        import os
        # for slurm
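        # SLURM_NODEID is only set when launched via SLURM; otherwise skip logging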
        try:
            node_id = int(os.environ['SLURM_NODEID'])
        except KeyError:
            return

        rank = dist.get_rank()
        world_size = dist.get_world_size()
        gpu_id = dist.gpu_id

        logger.info('World: {}/Node: {}/Rank: {}/GpuId: {} initialized.'
                    .format(world_size, node_id, rank, gpu_id))
Example #5
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

    # full precision pretrained
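    # initialize from a full-precision checkpoint: remap keys if requested,
    # expand BN entries per bit-width, and drop keys the model does not expect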
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = torch.load(
            FLAGS.fp_pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        #checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict}
        # switch bn
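        # create a BatchNorm entry for every bit-width in FLAGS.bits_list, filled
        # from the current model, so the state dict matches the switchable-BN layout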
        for k in list(checkpoint.keys()):
            if 'bn' in k:
                for bn_idx in range(len(FLAGS.bits_list)):
                    k_new = k.split('bn')[0] + 'bn' + k.split('bn')[1][0] + str(bn_idx) + k.split('bn')[1][2:]
                    mprint(k)
                    mprint(k_new)
                    checkpoint[k_new] = model_dict[k]
        if getattr(FLAGS, 'switch_alpha', False):
            for k, v in checkpoint.items():
                if 'alpha' in k and checkpoint[k].size() != model_dict[k].size():
                    #checkpoint[k] = checkpoint[k].repeat(model_dict[k].size())
                    checkpoint[k] = nn.Parameter(torch.stack([checkpoint[k] for _ in range(model_dict[k].size()[0])]))
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict.keys():
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(FLAGS.fp_pretrained_file))

    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(
            pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)

    if FLAGS.test_only and (test_loader is not None):
        mprint('Start profiling.')
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        mprint('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader,
                model_wrapper, criterion, optimizer,
                test_meters, phase='test')
        return

    # check resume training
    if os.path.exists(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
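        # iteration-based schedulers count steps, so resume at epoch * iterations-per-epoch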
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)

    if getattr(FLAGS, 'log_dir', None):
        try:
            os.makedirs(log_dir)
        except OSError:
            pass

    mprint('Start training.')

    for epoch in range(last_epoch+1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
            # For PyTorch 1.1+, comment the following line
            #lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', scheduler=lr_sched)

        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                torch.save(
                    {
                        'model': model_wrapper.state_dict(),
                    },
                    os.path.join(log_dir, 'best_model.pt'))
                mprint('New best validation top1 error: {:.3f}'.format(best_val))

            # save latest checkpoint
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                },
                os.path.join(log_dir, 'latest_checkpoint.pt'))

        # For PyTorch 1.0 or earlier, comment the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter', 'cos_annealing_iter', 'multistep_iter']:
            lr_scheduler.step()

    if is_master():
        profiling(model, use_cuda=True)

    return
Example #6
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    # experiment setting
    experiment_setting = get_experiment_setting()

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # optional exponential moving average (EMA) of the model weights
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
    else:
        ema = None

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
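    # checkpoints are read and written through an HDFS-backed I/O helper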
    io = UltronIO('hdfs://haruna/home')
    # full precision pretrained
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = io.torch_load(
            FLAGS.fp_pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        #checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict}
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict.keys():
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(FLAGS.fp_pretrained_file))

    # check pretrained
    if FLAGS.pretrained_file and FLAGS.pretrained_dir:
        pretrained_dir = FLAGS.pretrained_dir
        #pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = io.torch_load(
            pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        #if type(checkpoint) == dict and 'model' in checkpoint:
        #    checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        # expand per-layer clipping parameters (lamda_w / lamda_a) to match the
        # sizes expected by the current model
        for k, v in checkpoint['model'].items():
            if 'lamda_w' in k or 'lamda_a' in k:
                checkpoint['model'][k] = v.repeat(model_wrapper.state_dict()[k].size())
        model_wrapper.load_state_dict(checkpoint['model'])
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)

    if FLAGS.test_only and (test_loader is not None):
        mprint('Start testing.')
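        # reuse the EMA weights from the checkpoint loaded above, if it provides them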
        ema = checkpoint.get('ema', None)
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader,
                model_wrapper, criterion, optimizer,
                test_meters, phase='test', ema=ema)
        return

    # check resume training
    if io.check_path(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = io.torch_load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        ema = checkpoint.get('ema', None)
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)

    if getattr(FLAGS, 'log_dir', None):
        try:
            io.create_folder(log_dir)
        except OSError:
            pass

    mprint('Start training.')
    for epoch in range(last_epoch+1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
            # For PyTorch 1.1+, comment the following line
            #lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', ema=ema, scheduler=lr_sched)

        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            if epoch == getattr(FLAGS, 'hard_assign_epoch', float('inf')):
                mprint('Start to use hard assignment')
                setattr(FLAGS, 'hard_assignment', True)
                lower_offset = -1
                higher_offset = 0
                setattr(FLAGS, 'hard_offset', 0)


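                # binary-search FLAGS.hard_offset so the hard-assigned model meets the
                # target byte size (weight_only) or target BitOps within a 1% tolerance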
                with_ratio = 0.01
                bitops, bytesize = profiling(model, use_cuda=True)
                search_trials = 10
                trial = 0
                if getattr(FLAGS, 'weight_only', False):
                    target_bytesize = getattr(FLAGS, 'target_size', 0)
                    while trial < search_trials:
                        trial += 1
                        if bytesize - target_bytesize > with_ratio * target_bytesize:
                            higher_offset = FLAGS.hard_offset
                        elif bytesize - target_bytesize < -with_ratio * target_bytesize:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                else:
                    target_bitops = getattr(FLAGS, 'target_bitops', 0)
                    while trial < search_trials:
                        trial += 1
                        if bitops - target_bitops > with_ratio * target_bitops:
                            higher_offset = FLAGS.hard_offset
                        elif bitops - target_bitops < -with_ratio * target_bitops:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                bit_discretizing(model_wrapper)
                setattr(FLAGS, 'hard_offset', 0)
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val', ema=ema)
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                io.torch_save(
                    os.path.join(log_dir, 'best_model.pt'),
                    {
                        'model': model_wrapper.state_dict(),
                    }
                    )
                mprint('New best validation top1 error: {:.3f}'.format(best_val))

            # save latest checkpoint
            io.torch_save(
                os.path.join(log_dir, 'latest_checkpoint.pt'),
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                    'ema': ema,
                })

        # For PyTorch 1.0 or earlier, comment the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter', 'gaussian_iter', 'cos_annealing_iter', 'butterworth_iter', 'mixed_iter']:
            lr_scheduler.step()

    if is_master():
        profiling(model, use_cuda=True)
        for m in model.modules():
            if hasattr(m, 'alpha'):
                mprint(m, m.alpha)
            if hasattr(m, 'lamda_w'):
                mprint(m, m.lamda_w)
            if hasattr(m, 'lamda_a'):
                mprint(m, m.lamda_a)
    return
Example #7
def train(train_dataloader, model, optimizer, lr_scheduler):
    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    logger.info("model\n{}".format(describe(model.module)))
    tb_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)
    average_meter = AverageMeter()
    start_epoch = cfg.TRAIN.START_EPOCH
    world_size = get_world_size()
    num_per_epoch = len(
        train_dataloader.dataset) // (cfg.TRAIN.BATCH_SIZE * world_size)
    iter = 0
    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and get_rank() == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)
    for epoch in range(cfg.TRAIN.START_EPOCH, cfg.TRAIN.EPOCHS):
        if cfg.BACKBONE.TRAIN_EPOCH == epoch:
            logger.info('begin to train backbone!')
            optimizer, lr_scheduler = build_optimizer_lr(model.module, epoch)
            logger.info("model\n{}".format(describe(model.module)))
        train_dataloader.dataset.shuffle()
        lr_scheduler.step(epoch)
        # log for lr
        if get_rank() == 0:
            for idx, pg in enumerate(optimizer.param_groups):
                tb_writer.add_scalar('lr/group{}'.format(idx + 1), pg['lr'],
                                     iter)
        cur_lr = lr_scheduler.get_cur_lr()
        for data in train_dataloader:
            begin = time.time()
            examplar_img = data['examplar_img'].cuda()
            search_img = data['search_img'].cuda()
            gt_cls = data['gt_cls'].cuda()
            gt_delta = data['gt_delta'].cuda()
            delta_weight = data['delta_weight'].cuda()
            data_time = time.time() - begin
            losses = model.forward(examplar_img, search_img, gt_cls, gt_delta,
                                   delta_weight)
            cls_loss = losses['cls_loss']
            loc_loss = losses['loc_loss']
            loss = losses['total_loss']

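            # skip the update when the loss is NaN, inf, or larger than 1e4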
            if is_valid_number(loss.item()):
                optimizer.zero_grad()
                loss.backward()
                reduce_gradients(model)
                if get_rank() == 0 and cfg.TRAIN.LOG_GRAD:
                    log_grads(model.module, tb_writer, iter)
                clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
                optimizer.step()

            batch_time = time.time() - begin
            batch_info = {}
            batch_info['data_time'] = average_reduce(data_time)
            batch_info['batch_time'] = average_reduce(batch_time)
            for k, v in losses.items():
                batch_info[k] = average_reduce(v)
            average_meter.update(**batch_info)
            if get_rank() == 0:
                for k, v in batch_info.items():
                    tb_writer.add_scalar(k, v, iter)
                if iter % cfg.TRAIN.PRINT_EVERY == 0:
                    logger.info(
                        'epoch: {}, iter: {}, cur_lr:{}, cls_loss: {}, loc_loss: {}, loss: {}'
                        .format(epoch + 1, iter, cur_lr, cls_loss.item(),
                                loc_loss.item(), loss.item()))
                    print_speed(iter + 1 + start_epoch * num_per_epoch,
                                average_meter.batch_time.avg,
                                cfg.TRAIN.EPOCHS * num_per_epoch)
            iter += 1
        # save model
        if get_rank() == 0:
            state = {
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            logger.info('save snapshot to {}/checkpoint_e{}.pth'.format(
                cfg.TRAIN.SNAPSHOT_DIR, epoch + 1))
            torch.save(
                state, '{}/checkpoint_e{}.pth'.format(cfg.TRAIN.SNAPSHOT_DIR,
                                                      epoch + 1))
Example #8
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # seed
    set_random_seed()

    # for universally slimmable networks only
    if getattr(FLAGS, 'universally_slimmable_training', False):
        if getattr(FLAGS, 'test_only', False):
            if getattr(FLAGS, 'width_mult_list_test', None) is not None:
                FLAGS.test_only = False
                # skip training and go straight to BN calibration
                FLAGS.skip_training = True
        else:
            FLAGS.width_mult_list = FLAGS.width_mult_range

    # model
    model, model_wrapper = get_model()
    if getattr(FLAGS, 'label_smoothing', 0):
        criterion = CrossEntropyLossSmooth(reduction='none')
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
    if getattr(FLAGS, 'inplace_distill', False):
        soft_criterion = CrossEntropyLossSoft(reduction='none')
    else:
        soft_criterion = None

    # check pretrained
    if getattr(FLAGS, 'pretrained', False):
        checkpoint = torch.load(
            FLAGS.pretrained, map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                print('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        print('Loaded model {}.'.format(FLAGS.pretrained))

    optimizer = get_optimizer(model_wrapper)

    # check resume training
    if os.path.exists(os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = get_lr_scheduler(optimizer)
        lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        print('Loaded checkpoint {} at epoch {}.'.format(
            FLAGS.log_dir, last_epoch))
    else:
        lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if start from scratch, print model and do profiling
        print(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)
            if getattr(FLAGS, 'profiling_only', False):
                return

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    # autoslim only
    if getattr(FLAGS, 'autoslim', False):
        with torch.no_grad():
            slimming(train_loader, model_wrapper, criterion)
        return

    if getattr(FLAGS, 'test_only', False) and (test_loader is not None):
        print('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            if getattr(FLAGS, 'slimmable_training', False):
                for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                    model_wrapper.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))
                    run_one_epoch(
                        last_epoch, test_loader, model_wrapper, criterion,
                        optimizer, test_meters, phase='test')
            else:
                run_one_epoch(
                    last_epoch, test_loader, model_wrapper, criterion,
                    optimizer, test_meters, phase='test')
        return

    if getattr(FLAGS, 'nonuniform_diff_seed', False):
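        # give each process a distinct random seed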
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    print('Start training.')
    for epoch in range(last_epoch+1, FLAGS.num_epochs):
        if getattr(FLAGS, 'skip_training', False):
            print('Skip training at epoch: {}'.format(epoch))
            break
        lr_scheduler.step()
        # train
        results = run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', soft_criterion=soft_criterion)

        # val
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            results = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if is_master() and results['top1_error'] < best_val:
            best_val = results['top1_error']
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                },
                os.path.join(FLAGS.log_dir, 'best_model.pt'))
            print('New best validation top1 error: {:.3f}'.format(best_val))
        # save latest checkpoint
        if is_master():
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                },
                os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))

    if getattr(FLAGS, 'calibrate_bn', False):
        if getattr(FLAGS, 'universally_slimmable_training', False):
            # need to rebuild model according to width_mult_list_test
            width_mult_list = FLAGS.width_mult_range.copy()
            for width in FLAGS.width_mult_list_test:
                if width not in FLAGS.width_mult_list:
                    width_mult_list.append(width)
            FLAGS.width_mult_list = width_mult_list
            new_model, new_model_wrapper = get_model()
            profiling(new_model, use_cuda=True)
            new_model_wrapper.load_state_dict(
                model_wrapper.state_dict(), strict=False)
            model_wrapper = new_model_wrapper
        cal_meters = get_meters('cal')
        print('Start calibration.')
        results = run_one_epoch(
            -1, train_loader, model_wrapper, criterion, optimizer,
            cal_meters, phase='cal')
        print('Start validation after calibration.')
        with torch.no_grad():
            results = run_one_epoch(
                -1, val_loader, model_wrapper, criterion, optimizer,
                cal_meters, phase='val')
        if is_master():
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                },
                os.path.join(FLAGS.log_dir, 'best_model_bn_calibrated.pt'))
    return
Example #9
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed',
               False) and not getattr(FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()

    if getattr(FLAGS, 'adjust_lr', False):
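        # eta maps a bit-width to a scaling factor (lower precision gets a smaller value)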
        eta_dict = {
            32: 1.0,
            16: 1.0,
            8: 1.0,
            7: 0.99,
            6: 0.98,
            5: 0.97,
            4: 0.94,
            3: 0.88,
            2: 0.77,
            1: 0.58
        }
        eta = lambda b: eta_dict[b]  # noqa: E731
    else:
        eta = None

    # experiment setting
    experiment_setting = get_experiment_setting()
    mprint('stoch_valid: {}, bn_calib_stoch_valid: {}'.format(
        getattr(FLAGS, 'stoch_valid', False),
        getattr(FLAGS, 'bn_calib_stoch_valid', False)))

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return

    # optional exponential moving average (EMA) of the model weights
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
    else:
        ema = None

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)

    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(pretrained_file,
                                map_location=lambda storage, loc: storage)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)
    cal_meters = get_meters('cal', single_sample=True)
    mprint('Start calibration.')
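    # one pass over the training data in the 'cal' phase to recalibrate BN statistics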
    run_one_epoch(-1,
                  train_loader,
                  model_wrapper,
                  criterion,
                  optimizer,
                  cal_meters,
                  phase='cal',
                  ema=ema,
                  single_sample=True)
    mprint('Start validation after calibration.')
    with torch.no_grad():
        run_one_epoch(-1,
                      val_loader,
                      model_wrapper,
                      criterion,
                      optimizer,
                      cal_meters,
                      phase='val',
                      ema=ema,
                      single_sample=True)
    return
Example #10
def train(cfg):
    """
    Train function.
    Args:
        cfg (CfgNode) : configs. Details can be found in
            config.py
    """
    # Set random seed from configs.
    if cfg.RNG_SEED != -1:
        random.seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.NUM_GPUS, os.path.join(cfg.LOG_DIR, "log.txt"))

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Model for training.
    model = build_model(cfg)
    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Print model statistics.
    if du.is_master_proc(cfg.NUM_GPUS):
        misc.log_model_info(model, cfg, use_train_input=True)

    # Create dataloaders.
    train_loader = loader.construct_loader(cfg, 'train')
    val_loader = loader.construct_loader(cfg, 'val')

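    # derive the step budget from MAX_EPOCH when it is set; otherwise derive MAX_EPOCH from NUM_STEPS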
    if cfg.SOLVER.MAX_EPOCH != -1:
        max_epoch = cfg.SOLVER.MAX_EPOCH * cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS
        num_steps = max_epoch * len(train_loader)
        cfg.SOLVER.NUM_STEPS = cfg.SOLVER.MAX_EPOCH * len(train_loader)
        cfg.SOLVER.WARMUP_PROPORTION = cfg.SOLVER.WARMUP_EPOCHS / cfg.SOLVER.MAX_EPOCH
    else:
        num_steps = cfg.SOLVER.NUM_STEPS * cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS
        max_epoch = math.ceil(num_steps / len(train_loader))
        cfg.SOLVER.MAX_EPOCH = cfg.SOLVER.NUM_STEPS / len(train_loader)
        cfg.SOLVER.WARMUP_EPOCHS = cfg.SOLVER.MAX_EPOCH * cfg.SOLVER.WARMUP_PROPORTION

    start_epoch = 0
    global_step = 0
    if cfg.TRAIN.CHECKPOINT_FILE_PATH:
        if os.path.isfile(cfg.TRAIN.CHECKPOINT_FILE_PATH):
            logger.info(
                "=> loading checkpoint '{}'".format(
                    cfg.TRAIN.CHECKPOINT_FILE_PATH
                )
            )
            ms = model.module if cfg.NUM_GPUS > 1 else model
            # Load the checkpoint on CPU to avoid GPU mem spike.
            checkpoint = torch.load(
                cfg.TRAIN.CHECKPOINT_FILE_PATH, map_location='cpu'
            )
            start_epoch = checkpoint['epoch']
            ms.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            global_step = checkpoint['epoch'] * len(train_loader)
            logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    cfg.TRAIN.CHECKPOINT_FILE_PATH,
                    checkpoint['epoch']
                )
            )
    else:
        logger.info("Training with random initialization.")

    # Create meters.
    train_meter = TrainMeter(
        len(train_loader),
        num_steps,
        max_epoch,
        cfg
    )
    val_meter = ValMeter(
        len(val_loader),
        max_epoch,
        cfg
    )

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch+1))

    cudnn.benchmark = True

    best_epoch, best_top1_err, top5_err, best_map = 0, 100.0, 100.0, 0.0

    for cur_epoch in range(start_epoch, max_epoch):
        is_best_epoch = False
        # Shuffle the dataset.
        # loader.shuffle_dataset(train_loader, cur_epoch)
        # Pretrain for one epoch.
        global_step = train_epoch(
            train_loader,
            model,
            optimizer,
            train_meter,
            cur_epoch,
            global_step,
            num_steps,
            cfg
        )

        if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0:
            calculate_and_update_precise_bn(
                train_loader, model, cfg.BN.NUM_BATCHES_PRECISE
            )

        if misc.is_eval_epoch(cfg, cur_epoch, max_epoch):
            stats = eval_epoch(val_loader, model, val_meter, cur_epoch, cfg)
            if cfg.DATA.MULTI_LABEL:
                if best_map < float(stats["map"]):
                    best_epoch = cur_epoch + 1
                    best_map = float(stats["map"])
                    is_best_epoch = True
                logger.info(
                    "BEST: epoch: {}, best_map: {:.2f}".format(
                        best_epoch, best_map,
                    )
                )
            else:
                if best_top1_err > float(stats["top1_err"]):
                    best_epoch = cur_epoch + 1
                    best_top1_err = float(stats["top1_err"])
                    top5_err = float(stats["top5_err"])
                    is_best_epoch = True
                logger.info(
                    "BEST: epoch: {}, best_top1_err: {:.2f}, top5_err: {:.2f}".format(
                        best_epoch, best_top1_err, top5_err
                    )
                )

        sd = \
            model.module.state_dict() if cfg.NUM_GPUS > 1 else \
            model.state_dict()

        ckpt = {
            'epoch': cur_epoch + 1,
            'model_arch': cfg.MODEL.DOWNSTREAM_ARCH,
            'state_dict': sd,
            'optimizer': optimizer.state_dict(),
        }

        if (cur_epoch + 1) % cfg.SAVE_EVERY_EPOCH == 0 and du.get_rank() == 0:
            save_checkpoint(
                ckpt,
                filename=os.path.join(cfg.SAVE_DIR, f'epoch{cur_epoch+1}.pyth')
            )

        if is_best_epoch and du.get_rank() == 0:
            save_checkpoint(
                ckpt,
                filename=os.path.join(cfg.SAVE_DIR, 'epoch_best.pyth')
            )


def main():
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    global best_acc1

    run_time = str(datetime.datetime.now())
    cfgs = load_configs(CONFIG_FILE)

    # create log dir and weight dir
    mkdir(cfgs['weight_dir'])
    mkdir(cfgs['log_dir'])

    # create logger
    log_dir = osp.join(cfgs['log_dir'], cfgs['arch'])
    mkdir(log_dir)

    cfgs['log_name'] = cfgs['arch'] + '_' + cfgs['dataset']
    logger = setup_logger(cfgs['log_name'], log_dir, get_rank(),
                          run_time + '.txt')

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())
    logger.info("Loaded configuration file {}".format(CONFIG_FILE))
    logger.info("Running with config:\n{}".format(cfgs))

    # create model
    logger.info("=> creating model '{}'".format(cfgs['arch']))

    if cfgs['arch'].lower().startswith('wideresnet'):
        # a customized resnet model with last feature map size as 14x14 for better class activation mapping
        model = wideresnet.resnet50(num_classes=cfgs['num_classes'])
    else:
        model = models.__dict__[cfgs['arch']](num_classes=cfgs['num_classes'])

    if cfgs['arch'].lower().startswith(
            'alexnet') or cfgs['arch'].lower().startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    logger.info("=> created model '{}'".format(model.__class__.__name__))
    logger.info("model structure: {}".format(model))
    num_gpus = torch.cuda.device_count()
    logger.info("using {} GPUs".format(num_gpus))

    # optionally resume from a checkpoint
    if cfgs['resume']:
        if osp.isfile(cfgs['resume']):
            logger.info("=> loading checkpoint '{}'".format(cfgs['resume']))
            checkpoint = torch.load(cfgs['resume'])
            cfgs['start_epoch'] = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                cfgs['resume'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                cfgs['resume']))

    torch.backends.cudnn.benchmark = True

    # Data loading code
    traindir = osp.join(cfgs['data_path'], 'train')
    valdir = osp.join(cfgs['data_path'], 'val')
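    # standard ImageNet channel mean/std for normalization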
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=cfgs['batch_size'],
        shuffle=True,
        num_workers=cfgs['workers'],
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=cfgs['batch_size'],
        shuffle=False,
        num_workers=cfgs['workers'],
        pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                cfgs['lr'],
                                momentum=cfgs['momentum'],
                                weight_decay=float(cfgs['weight_decay']))

    # if cfgs['evaluate']:
    #     validate(val_loader, model, criterion, cfgs)
    #     return

    # for epoch in range(cfgs['start_epoch'], cfgs['epochs']):
    #     adjust_learning_rate(optimizer, epoch, cfgs)

    #     # train for one epoch
    #     train(train_loader, model, criterion, optimizer, epoch, cfgs)

    #     # evaluate on validation set
    #     acc1 = validate(val_loader, model, criterion, cfgs)

    #     # remember best acc@1 and save checkpoint
    #     is_best = acc1 > best_acc1
    #     best_acc1 = max(acc1, best_acc1)
    #     save_checkpoint({
    #         'epoch': epoch + 1,
    #         'arch': cfgs['arch'],
    #         'state_dict': model.state_dict(),
    #         'best_acc1': best_acc1,
    #     }, is_best, cfgs['weight_dir'] + '/' + cfgs['arch'].lower())

    logger.info("start to test the best model")
    best_weight = (cfgs['weight_dir'] + '/' +
                   cfgs['arch'].lower() + '_best.pth.tar')
    if osp.isfile(best_weight):
        logger.info("=> loading best model '{}'".format(best_weight))
        checkpoint = torch.load(best_weight)
        best_acc1 = checkpoint['best_acc1']
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        logger.info("=> loaded checkpoint '{}' (val Acc@1 {})".format(
            best_weight, best_acc1))
    else:
        logger.info("=> no best model found at '{}'".format(best_weight))

    acc1 = validate(val_loader, model, criterion, cfgs)