Exemplo n.º 1
0
 def get_single_meter(phase, suffix=''):
     """Build a dict of ScalarMeters for loss and each top-k error."""
     meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
     for topk in FLAGS.topk:
         meters['top{}_error'.format(topk)] = ScalarMeter(
             '{}_top{}_error/{}'.format(phase, topk, suffix))
     return meters
Exemplo n.º 2
0
def get_meters(phase):
    """Create loss and top-k error meters for the given phase."""
    meters = {'loss': ScalarMeter('{}_loss'.format(phase))}
    for topk in FLAGS.topk:
        tag = '{}_top{}_error'.format(phase, topk)
        meters['top{}_error'.format(topk)] = ScalarMeter(tag)
    return meters
Exemplo n.º 3
0
def get_seg_meters(phase, pruning=False):
    """Create segmentation meters; add pruning-loss meters when training."""
    meters = {
        'loss': ScalarMeter('{}_loss'.format(phase)),
        'acc': ScalarMeter('{}_acc'.format(phase)),
    }
    if pruning and phase == 'train':
        # Extra regularization terms are only tracked while pruning-training.
        meters['loss_l2'] = ScalarMeter('{}_loss_l2'.format(phase))
        meters['loss_bn_l1'] = ScalarMeter('{}_loss_bn_l1'.format(phase))
    return meters
Exemplo n.º 4
0
 def get_single_meter(phase, suffix=''):
     """Meters for loss, top-k error, and (train only) the learning rate."""
     meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
     for topk in FLAGS.topk:
         meters['top{}_error'.format(topk)] = ScalarMeter(
             '{}_top{}_error/{}'.format(phase, topk, suffix))
     if phase == 'train':
         meters['lr'] = ScalarMeter('learning_rate')
     return meters
Exemplo n.º 5
0
def get_meters(phase, pruning=False):
    """Create loss/top-k meters; add pruning-loss meters when training."""
    meters = {'loss': ScalarMeter('{}_loss'.format(phase))}
    if pruning and phase == 'train':
        # Regularization terms are only meaningful during pruning-training.
        meters['loss_l2'] = ScalarMeter('{}_loss_l2'.format(phase))
        meters['loss_bn_l1'] = ScalarMeter('{}_loss_bn_l1'.format(phase))
    for topk in FLAGS.topk:
        tag = '{}_top{}_error'.format(phase, topk)
        meters['top{}_error'.format(topk)] = ScalarMeter(tag)
    return meters
Exemplo n.º 6
0
def get_meters(phase, single_sample=False):
    """Create meters for a phase; one group per bit-width when quantizable."""
    def get_single_meter(phase, suffix=''):
        meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
        for topk in FLAGS.topk:
            meters['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, topk, suffix))
        return meters

    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    if not single_sample and getattr(FLAGS, 'quantizable_training', False):
        # One meter group per bit-width under quantizable training;
        # single_sample forces the plain single-group layout instead.
        meters = {str(bits): get_single_meter(phase, str(bits))
                  for bits in FLAGS.bits_list}
    else:
        meters = get_single_meter(phase)
    if phase == 'val':
        meters['best_val'] = ScalarMeter('best_val')
    return meters
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch with NLL loss and periodic rank-0 logging.

    Args:
        args: namespace providing gpus, log_interval and dry_run.
        model: network to optimize (already placed on the target device).
        device: target device; currently unused -- tensors are moved with
            .cuda() directly (kept for interface compatibility).
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch index, used only for log messages.
    """
    model.train()
    loss_meter = ScalarMeter(args.log_interval)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if args.gpus > 1:
            # Average the loss across workers so logging reflects all GPUs.
            [loss] = du.all_reduce([loss])

        if dist.get_rank() == 0:
            loss_meter.add_value(loss.item())

        if batch_idx % args.log_interval == 0 and dist.get_rank() == 0:
            # Bug fix: the original rebound `loss` to get_win_median() and
            # then called .item() on it; get_win_median() may return a plain
            # float (no .item()), crashing the multi-GPU branch. Resolve the
            # scalar value first and format that instead.
            if args.gpus > 1:
                loss_value = loss_meter.get_win_median()
            else:
                loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data) * args.gpus, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
            if args.dry_run:
                break
Exemplo n.º 8
0
def get_meters(phase):
    """Create meters; one group per width multiplier when slimmable."""
    if not getattr(FLAGS, 'slimmable_training', False):
        # Plain single-group layout.
        meters = {'loss': ScalarMeter('{}_loss'.format(phase))}
        for topk in FLAGS.topk:
            meters['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error'.format(phase, topk))
        return meters
    # Slimmable training: a separate meter group per width multiplier.
    meters_all = {}
    for width_mult in FLAGS.width_mult_list:
        group = {'loss': ScalarMeter(
            '{}_loss/{}'.format(phase, str(width_mult)))}
        for topk in FLAGS.topk:
            group['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, topk, str(width_mult)))
        meters_all[str(width_mult)] = group
    return meters_all
Exemplo n.º 9
0
def get_meters(phase, topk, width_mult_list, slimmable=True):
    """Build meter dicts; one group per width multiplier when slimmable."""
    if not slimmable:
        # Plain single-group layout.
        meters = {'loss': ScalarMeter('{}_loss'.format(phase))}
        for k in topk:
            meters['top{}_error'.format(k)] = ScalarMeter(
                '{}_top{}_error'.format(phase, k))
        return meters
    # Slimmable: separate meter group per width multiplier, keyed by str.
    meters_all = {}
    for width_mult in width_mult_list:
        width_key = str(width_mult)
        group = {'loss': ScalarMeter('{}_loss/{}'.format(phase, width_key))}
        for k in topk:
            group['top{}_error'.format(k)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, k, width_key))
        meters_all[width_key] = group
    return meters_all
Exemplo n.º 10
0
def get_meters(phase):
    """Create loss/top-k meters; track best_val during validation."""
    def get_single_meter(phase, suffix=''):
        meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
        for topk in FLAGS.topk:
            meters['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, topk, suffix))
        return meters

    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    meters = get_single_meter(phase)
    if phase == 'val':
        # Validation additionally tracks the best error seen so far.
        meters['best_val'] = ScalarMeter('best_val')
    return meters
Exemplo n.º 11
0
def get_meters(phase, single_sample=False):
    """Create meters for a phase; one group per bit-width when adaptive."""
    def get_single_meter(phase, suffix=''):
        meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
        for topk in FLAGS.topk:
            meters['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, topk, suffix))
        return meters

    assert phase in ['train', 'val', 'test'], 'Invalid phase.'
    if not single_sample and getattr(FLAGS, 'adaptive_training', False):
        # One meter group per bit-width under adaptive training;
        # single_sample forces the plain single-group layout instead.
        meters = {str(bits): get_single_meter(phase, str(bits))
                  for bits in FLAGS.bits_list}
    else:
        meters = get_single_meter(phase)
    if phase == 'val':
        meters['best_val'] = ScalarMeter('best_val')
    return meters
Exemplo n.º 12
0
def get_meters(phase):
    """Create meters; one group per width multiplier when slimmable."""
    def get_single_meter(phase, suffix=''):
        meters = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
        for topk in FLAGS.topk:
            meters['top{}_error'.format(topk)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, topk, suffix))
        if phase == 'train':
            meters['lr'] = ScalarMeter('learning_rate')
        return meters

    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'
    if getattr(FLAGS, 'slimmable_training', False):
        # One meter group per width multiplier, keyed by its string form.
        meters = {str(width_mult): get_single_meter(phase, str(width_mult))
                  for width_mult in FLAGS.width_mult_list}
    else:
        meters = get_single_meter(phase)
    if phase == 'val':
        meters['best_val'] = ScalarMeter('best_val')
    return meters
Exemplo n.º 13
0
def train_val_test():
    """Train with per-epoch validation, or run test-only evaluation.

    Driven entirely by the global FLAGS: optionally loads a pretrained
    model, resumes from the latest checkpoint in FLAGS.log_dir when one
    exists, and saves both the best and the latest checkpoints there.
    """
    torch.backends.cudnn.benchmark = True
    # seed
    set_random_seed()

    # model
    model = get_model()
    model_wrapper = torch.nn.DataParallel(model).cuda()
    # reduction='none' keeps per-sample losses; presumably aggregated
    # inside run_one_epoch -- TODO confirm.
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained)
        # update keys from external models
        if type(checkpoint) == dict and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys') and
                FLAGS.pretrained_model_remap_keys):
            # Positional remap: assumes both state dicts enumerate their
            # parameters in the same order -- fragile, verify per model.
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                # NOTE(review): message wording suggests new->old but the
                # copy goes old->new; wording only, values are correct.
                print('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        print('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = get_optimizer(model_wrapper)
    # check resume training
    if os.path.exists(os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = get_lr_scheduler(optimizer)
        # Fast-forward the scheduler instead of replaying step() calls.
        lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        print('Loaded checkpoint {} at epoch {}.'.format(
            FLAGS.log_dir, last_epoch))
    else:
        lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.  # best top-1 error so far; 1.0 means 100% error
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        val_meters['best_val'] = ScalarMeter('best_val')
        # if start from scratch, print model and do profiling
        print(model_wrapper)
        if FLAGS.profiling:
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)

    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)

    # Test-only path: one no-grad pass over the test loader, then exit.
    if FLAGS.test_only and (test_loader is not None):
        print('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                last_epoch, test_loader, model_wrapper, criterion, optimizer,
                test_meters, phase='test')
        return

    print('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # NOTE(review): stepping before the epoch follows the pre-1.1
        # PyTorch scheduler convention this code was written against.
        lr_scheduler.step()
        # train
        results = run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train')

        # val
        val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            results = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if results['top1_error'] < best_val:
            best_val = results['top1_error']
            # Best checkpoint stores only the weights.
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                },
                os.path.join(FLAGS.log_dir, 'best_model.pt'))
            print('New best validation top1 error: {:.3f}'.format(best_val))

        # save latest checkpoint
        torch.save(
            {
                'model': model_wrapper.state_dict(),
                'optimizer': optimizer.state_dict(),
                'last_epoch': epoch,
                'best_val': best_val,
                'meters': (train_meters, val_meters),
            },
            os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))
    return
Exemplo n.º 14
0
def run_one_experiment():
    """Run one full experiment driven by the global FLAGS.

    Sets up logging, seeding, model, and data loaders; then either runs a
    test-only sweep over width multipliers or the full train/validate/
    checkpoint loop.
    """
    t_exp_start = time.time()

    # Save all print-out to a logger file
    logger = Logger(FLAGS.log_file)

    # Print experience setup
    for k in sorted(FLAGS.keys()):
        print('{}: {}'.format(k, FLAGS[k]))

    # Init torch
    if FLAGS.seed is None:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
    else:
        # Full determinism: seed every RNG and disable cuDNN autotuning.
        random.seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Init model
    model = importlib.import_module(FLAGS.module_name).get_model(FLAGS)
    model = torch.nn.DataParallel(model).cuda()

    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained)
        model.module.load_state_dict(checkpoint['model'])
        print('Loaded model {}.'.format(FLAGS.pretrained))

    if FLAGS.model_profiling and len(FLAGS.model_profiling) > 0:
        print(model)
        profiling(model, FLAGS.model_profiling, FLAGS.image_size,
                  FLAGS.image_channels, FLAGS.train_width_mults,
                  FLAGS.model_profiling_verbose)
    logger.flush()

    # Init data loaders
    train_loader, val_loader, _, train_set = prepare_data(
        FLAGS.dataset, FLAGS.data_dir, FLAGS.data_transforms,
        FLAGS.data_loader, FLAGS.data_loader_workers, FLAGS.train_batch_size,
        FLAGS.val_batch_size, FLAGS.drop_last, FLAGS.test_only)
    # NOTE(review): class_labels is never used below in this function.
    class_labels = train_set.classes

    # Perform inference/test only
    if FLAGS.test_only:
        print('Start testing...')
        # Evenly spaced width multipliers between the trained min and max;
        # the max is always appended so it is evaluated exactly once.
        min_wm = min(FLAGS.train_width_mults)
        max_wm = max(FLAGS.train_width_mults)
        if FLAGS.test_num_width_mults == 1:
            test_width_mults = []
        else:
            step = (max_wm - min_wm) / (FLAGS.test_num_width_mults - 1)
            test_width_mults = np.arange(min_wm, max_wm, step).tolist()
        test_width_mults += [max_wm]

        criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
        test_meters = get_meters('val', FLAGS.topk, test_width_mults)
        epoch = -1

        avg_error1, _ = test(epoch,
                             val_loader,
                             model,
                             criterion,
                             test_meters,
                             test_width_mults,
                             topk=FLAGS.topk)
        print('==> Epoch avg accuracy {:.2f}%,'.format((1 - avg_error1) * 100))

        logger.close()
        plot_acc_width(FLAGS.log_file)
        return

    # Init training devices
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    optimizer = get_optimizer(model,
                              FLAGS.optimizer,
                              FLAGS.weight_decay,
                              FLAGS.lr,
                              FLAGS.momentum,
                              FLAGS.nesterov,
                              depthwise=FLAGS.depthwise)
    lr_scheduler = get_lr_scheduler(optimizer, FLAGS.lr_scheduler,
                                    FLAGS.lr_scheduler_params)

    train_meters = get_meters('train', FLAGS.topk, FLAGS.train_width_mults)
    val_meters = get_meters('val', FLAGS.topk, FLAGS.train_width_mults)
    val_meters['best_val_error1'] = ScalarMeter('best_val_error1')

    # NOTE(review): time_meter is created but never updated below.
    time_meter = ScalarMeter('runtime')

    # Perform training
    print('Start training...')
    last_epoch = -1
    best_val_error1 = 1.  # best top-1 error so far; 1.0 means 100% error
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        t_epoch_start = time.time()
        print('\nEpoch {}/{}.'.format(epoch + 1, FLAGS.num_epochs) +
              ' Print format: [width factor, loss, accuracy].' +
              ' Learning rate: {}'.format(optimizer.param_groups[0]['lr']))

        # Train one epoch
        steps_per_epoch = len(train_loader.dataset) / FLAGS.train_batch_size
        total_steps = FLAGS.num_epochs * steps_per_epoch
        # NOTE(review): this ternary is immediately recomputed by the if
        # below with an algebraically equivalent formula; redundant but
        # harmless.
        lr_decay_per_step = (None if FLAGS.lr_scheduler != 'linear_decaying'
                             else FLAGS.lr / total_steps)
        if FLAGS.lr_scheduler == 'linear_decaying':
            lr_decay_per_step = (FLAGS.lr / FLAGS.num_epochs /
                                 len(train_loader.dataset) *
                                 FLAGS.train_batch_size)
        train_results = train(epoch, FLAGS.num_epochs, train_loader, model,
                              criterion, optimizer, train_meters,
                              FLAGS.train_width_mults, FLAGS.log_interval,
                              FLAGS.topk, FLAGS.rand_width_mult_args,
                              lr_decay_per_step)

        # Validate
        avg_error1, val_results = test(epoch,
                                       val_loader,
                                       model,
                                       criterion,
                                       val_meters,
                                       FLAGS.train_width_mults,
                                       topk=FLAGS.topk)

        # Update best result
        is_best = avg_error1 < best_val_error1
        if is_best:
            best_val_error1 = avg_error1
        val_meters['best_val_error1'].cache(best_val_error1)

        # Save checkpoint
        print()
        if FLAGS.saving_checkpoint:
            save_model(model, optimizer, epoch, FLAGS.train_width_mults,
                       FLAGS.rand_width_mult_args, train_meters, val_meters,
                       1 - avg_error1, 1 - best_val_error1,
                       FLAGS.epoch_checkpoint, is_best, FLAGS.best_checkpoint)
        print('==> Epoch avg accuracy {:.2f}%,'.format((1 - avg_error1) * 100),
              'Best accuracy: {:.2f}%\n'.format((1 - best_val_error1) * 100))

        logger.flush()

        # Skip the scheduler step after the final epoch.
        if lr_scheduler is not None and epoch != FLAGS.num_epochs - 1:
            lr_scheduler.step()
        print('Epoch time: {:.4f} mins'.format(
            (time.time() - t_epoch_start) / 60))

    print('Total time: {:.4f} mins'.format((time.time() - t_exp_start) / 60))
    logger.close()
    return