def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of NLL-loss training with periodic logging.

    Args:
        args: namespace with `log_interval`, `gpus`, `dry_run`.
        model: the network to train (already on GPU).
        device: target device — NOTE(review): currently unused, the code
            moves tensors with `.cuda()` directly; confirm CUDA is assumed.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch index, used only for log messages.
    """
    model.train()
    loss_meter = ScalarMeter(args.log_interval)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Average the loss across workers for logging only; doing it after
        # step() keeps the reduction out of the gradient computation.
        if args.gpus > 1:
            [loss] = du.all_reduce([loss])

        if dist.get_rank() == 0:
            loss_meter.add_value(loss.item())

        if batch_idx % args.log_interval == 0 and dist.get_rank() == 0:
            # BUG FIX: get_win_median() returns a plain float; the original
            # code rebound `loss` to it and then called `.item()`, which
            # crashed with AttributeError whenever args.gpus > 1.
            if args.gpus > 1:
                loss_value = loss_meter.get_win_median()
            else:
                loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data) * args.gpus, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
            if args.dry_run:
                break
# Example #2 (score: 0)
def get_meters(phase, topk, width_mult_list, slimmable=True):
    """Util function for meters"""
    def _build(suffix):
        # One loss meter plus one top-k error meter per requested k,
        # with names optionally tagged by a '/<width_mult>' suffix.
        built = {'loss': ScalarMeter('{}_loss{}'.format(phase, suffix))}
        for k in topk:
            built['top{}_error'.format(k)] = ScalarMeter(
                '{}_top{}_error{}'.format(phase, k, suffix))
        return built

    if not slimmable:
        return _build('')
    # Slimmable training: a full meter set per width multiplier, keyed
    # by the width multiplier's string form.
    return {
        str(width_mult): _build('/{}'.format(str(width_mult)))
        for width_mult in width_mult_list
    }
# Example #3 (score: 0)
def get_meters(phase):
    """util function for meters"""
    assert phase in ['train', 'val', 'test', 'cal'], 'Invalid phase.'

    def _make(phase, suffix=''):
        # Loss + top-k error meters; the training phase also tracks the
        # learning rate.
        made = {'loss': ScalarMeter('{}_loss/{}'.format(phase, suffix))}
        for k in FLAGS.topk:
            made['top{}_error'.format(k)] = ScalarMeter(
                '{}_top{}_error/{}'.format(phase, k, suffix))
        if phase == 'train':
            made['lr'] = ScalarMeter('learning_rate')
        return made

    if getattr(FLAGS, 'slimmable_training', False):
        # One meter set per width multiplier, keyed by its string form.
        meters = {
            str(width_mult): _make(phase, str(width_mult))
            for width_mult in FLAGS.width_mult_list
        }
    else:
        meters = _make(phase)
    if phase == 'val':
        meters['best_val'] = ScalarMeter('best_val')
    return meters
# Example #4 (score: 0)
def run_one_experiment():
    """Run a full experiment driven by the global FLAGS config.

    Sets up logging, seeding, the model (optionally from a pretrained
    checkpoint), and the data loaders; then either runs test-only
    evaluation over a sweep of width multipliers, or trains for
    FLAGS.num_epochs with per-epoch validation and checkpointing.
    """
    t_exp_start = time.time()

    # Save all print-out to a logger file
    logger = Logger(FLAGS.log_file)

    # Print experiment setup
    for k in sorted(FLAGS.keys()):
        print('{}: {}'.format(k, FLAGS[k]))

    # Init torch: fully deterministic (and slower) when a seed is given.
    if FLAGS.seed is None:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
    else:
        random.seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Init model (module resolved by name so architectures are pluggable).
    model = importlib.import_module(FLAGS.module_name).get_model(FLAGS)
    model = torch.nn.DataParallel(model).cuda()

    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained)
        model.module.load_state_dict(checkpoint['model'])
        print('Loaded model {}.'.format(FLAGS.pretrained))

    if FLAGS.model_profiling and len(FLAGS.model_profiling) > 0:
        print(model)
        profiling(model, FLAGS.model_profiling, FLAGS.image_size,
                  FLAGS.image_channels, FLAGS.train_width_mults,
                  FLAGS.model_profiling_verbose)
    logger.flush()

    # Init data loaders
    train_loader, val_loader, _, train_set = prepare_data(
        FLAGS.dataset, FLAGS.data_dir, FLAGS.data_transforms,
        FLAGS.data_loader, FLAGS.data_loader_workers, FLAGS.train_batch_size,
        FLAGS.val_batch_size, FLAGS.drop_last, FLAGS.test_only)
    class_labels = train_set.classes

    # Perform inference/test only
    if FLAGS.test_only:
        print('Start testing...')
        # Evaluate on an evenly spaced sweep of width multipliers; the max
        # width is always appended so it is tested exactly once when
        # test_num_width_mults == 1.
        # NOTE(review): np.arange with a float step can drift; a value very
        # close to max_wm could slip into the sweep — confirm acceptable.
        min_wm = min(FLAGS.train_width_mults)
        max_wm = max(FLAGS.train_width_mults)
        if FLAGS.test_num_width_mults == 1:
            test_width_mults = []
        else:
            step = (max_wm - min_wm) / (FLAGS.test_num_width_mults - 1)
            test_width_mults = np.arange(min_wm, max_wm, step).tolist()
        test_width_mults += [max_wm]

        # reduction='none' keeps per-sample losses for the meters.
        criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
        test_meters = get_meters('val', FLAGS.topk, test_width_mults)
        epoch = -1

        avg_error1, _ = test(epoch,
                             val_loader,
                             model,
                             criterion,
                             test_meters,
                             test_width_mults,
                             topk=FLAGS.topk)
        print('==> Epoch avg accuracy {:.2f}%,'.format((1 - avg_error1) * 100))

        logger.close()
        plot_acc_width(FLAGS.log_file)
        return

    # Init training devices
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    optimizer = get_optimizer(model,
                              FLAGS.optimizer,
                              FLAGS.weight_decay,
                              FLAGS.lr,
                              FLAGS.momentum,
                              FLAGS.nesterov,
                              depthwise=FLAGS.depthwise)
    lr_scheduler = get_lr_scheduler(optimizer, FLAGS.lr_scheduler,
                                    FLAGS.lr_scheduler_params)

    train_meters = get_meters('train', FLAGS.topk, FLAGS.train_width_mults)
    val_meters = get_meters('val', FLAGS.topk, FLAGS.train_width_mults)
    val_meters['best_val_error1'] = ScalarMeter('best_val_error1')

    time_meter = ScalarMeter('runtime')

    # Per-step linear LR decay: lr / total_steps, where
    # total_steps = num_epochs * len(dataset) / batch_size.
    # FIX: the original computed this twice per epoch (the first assignment
    # was dead code, immediately overwritten by an equivalent one); it is
    # loop-invariant, so compute it once here.
    lr_decay_per_step = None
    if FLAGS.lr_scheduler == 'linear_decaying':
        lr_decay_per_step = (FLAGS.lr / FLAGS.num_epochs /
                             len(train_loader.dataset) *
                             FLAGS.train_batch_size)

    # Perform training
    print('Start training...')
    last_epoch = -1
    best_val_error1 = 1.
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        t_epoch_start = time.time()
        print('\nEpoch {}/{}.'.format(epoch + 1, FLAGS.num_epochs) +
              ' Print format: [width factor, loss, accuracy].' +
              ' Learning rate: {}'.format(optimizer.param_groups[0]['lr']))

        # Train one epoch
        train_results = train(epoch, FLAGS.num_epochs, train_loader, model,
                              criterion, optimizer, train_meters,
                              FLAGS.train_width_mults, FLAGS.log_interval,
                              FLAGS.topk, FLAGS.rand_width_mult_args,
                              lr_decay_per_step)

        # Validate
        avg_error1, val_results = test(epoch,
                                       val_loader,
                                       model,
                                       criterion,
                                       val_meters,
                                       FLAGS.train_width_mults,
                                       topk=FLAGS.topk)

        # Update best result (lower top-1 error is better).
        is_best = avg_error1 < best_val_error1
        if is_best:
            best_val_error1 = avg_error1
        val_meters['best_val_error1'].cache(best_val_error1)

        # Save checkpoint
        print()
        if FLAGS.saving_checkpoint:
            save_model(model, optimizer, epoch, FLAGS.train_width_mults,
                       FLAGS.rand_width_mult_args, train_meters, val_meters,
                       1 - avg_error1, 1 - best_val_error1,
                       FLAGS.epoch_checkpoint, is_best, FLAGS.best_checkpoint)
        print('==> Epoch avg accuracy {:.2f}%,'.format((1 - avg_error1) * 100),
              'Best accuracy: {:.2f}%\n'.format((1 - best_val_error1) * 100))

        logger.flush()

        # Skip the scheduler step after the final epoch so the logged LR
        # matches the one actually used.
        if lr_scheduler is not None and epoch != FLAGS.num_epochs - 1:
            lr_scheduler.step()
        print('Epoch time: {:.4f} mins'.format(
            (time.time() - t_epoch_start) / 60))

    print('Total time: {:.4f} mins'.format((time.time() - t_exp_start) / 60))
    logger.close()
    return