Example #1
def run_one_experiment():
    t_exp_start = time.time()

    # Save all print-out to a logger file
    logger = Logger(FLAGS.log_file)

    # Print experiment setup
    for k in sorted(FLAGS.keys()):
        print('{}: {}'.format(k, FLAGS[k]))

    # Init torch
    if FLAGS.seed is None:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
    else:
        random.seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        torch.manual_seed(FLAGS.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Init model
    model = importlib.import_module(FLAGS.module_name).get_model(FLAGS)
    model = torch.nn.DataParallel(model).cuda()
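    # Note: DataParallel wraps the network, so its weights are accessed
    # through model.module (as done below when loading a pretrained checkpoint)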

    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained)
        model.module.load_state_dict(checkpoint['model'])
        print('Loaded model {}.'.format(FLAGS.pretrained))

    if FLAGS.model_profiling and len(FLAGS.model_profiling) > 0:
        print(model)
        profiling(model, FLAGS.model_profiling, FLAGS.image_size,
                  FLAGS.image_channels, FLAGS.train_width_mults,
                  FLAGS.model_profiling_verbose)
    logger.flush()

    # Init data loaders
    train_loader, val_loader, _, train_set = prepare_data(
        FLAGS.dataset, FLAGS.data_dir, FLAGS.data_transforms,
        FLAGS.data_loader, FLAGS.data_loader_workers, FLAGS.train_batch_size,
        FLAGS.val_batch_size, FLAGS.drop_last, FLAGS.test_only)
    class_labels = train_set.classes

    # Perform inference/test only
    if FLAGS.test_only:
        print('Start testing...')
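        # Build the list of width multipliers to test: an approximately evenly
        # spaced sweep over the trained range, always including the largest
        # trained width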
        min_wm = min(FLAGS.train_width_mults)
        max_wm = max(FLAGS.train_width_mults)
        if FLAGS.test_num_width_mults == 1:
            test_width_mults = []
        else:
            step = (max_wm - min_wm) / (FLAGS.test_num_width_mults - 1)
            test_width_mults = np.arange(min_wm, max_wm, step).tolist()
        test_width_mults += [max_wm]

        criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
        test_meters = get_meters('val', FLAGS.topk, test_width_mults)
        epoch = -1

        avg_error1, _ = test(epoch,
                             val_loader,
                             model,
                             criterion,
                             test_meters,
                             test_width_mults,
                             topk=FLAGS.topk)
        print('==> Epoch avg accuracy {:.2f}%'.format((1 - avg_error1) * 100))

        logger.close()
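        # Plot accuracy vs. width multiplier from the log file just written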
        plot_acc_width(FLAGS.log_file)
        return

    # Init training components: loss criterion, optimizer, and LR scheduler
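    # reduction='none' keeps per-sample losses; averaging over samples and
    # widths is presumably handled inside train()/test()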
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    optimizer = get_optimizer(model,
                              FLAGS.optimizer,
                              FLAGS.weight_decay,
                              FLAGS.lr,
                              FLAGS.momentum,
                              FLAGS.nesterov,
                              depthwise=FLAGS.depthwise)
    lr_scheduler = get_lr_scheduler(optimizer, FLAGS.lr_scheduler,
                                    FLAGS.lr_scheduler_params)

    # Per-width-multiplier meters for training and validation statistics
    train_meters = get_meters('train', FLAGS.topk, FLAGS.train_width_mults)
    val_meters = get_meters('val', FLAGS.topk, FLAGS.train_width_mults)
    val_meters['best_val_error1'] = ScalarMeter('best_val_error1')

    time_meter = ScalarMeter('runtime')

    # Perform training
    print('Start training...')
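    # Track the lowest (best) average top-1 validation error seen so far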
    last_epoch = -1
    best_val_error1 = 1.
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        t_epoch_start = time.time()
        print('\nEpoch {}/{}.'.format(epoch + 1, FLAGS.num_epochs) +
              ' Print format: [width factor, loss, accuracy].' +
              ' Learning rate: {}'.format(optimizer.param_groups[0]['lr']))

        # Train one epoch
        steps_per_epoch = len(train_loader.dataset) / FLAGS.train_batch_size
        total_steps = FLAGS.num_epochs * steps_per_epoch
        # For the 'linear_decaying' scheduler, the LR is reduced by a fixed
        # amount after every training step
        lr_decay_per_step = (FLAGS.lr / total_steps
                             if FLAGS.lr_scheduler == 'linear_decaying'
                             else None)
        train_results = train(epoch, FLAGS.num_epochs, train_loader, model,
                              criterion, optimizer, train_meters,
                              FLAGS.train_width_mults, FLAGS.log_interval,
                              FLAGS.topk, FLAGS.rand_width_mult_args,
                              lr_decay_per_step)

        # Validate
        avg_error1, val_results = test(epoch,
                                       val_loader,
                                       model,
                                       criterion,
                                       val_meters,
                                       FLAGS.train_width_mults,
                                       topk=FLAGS.topk)

        # Update best result
        is_best = avg_error1 < best_val_error1
        if is_best:
            best_val_error1 = avg_error1
        val_meters['best_val_error1'].cache(best_val_error1)

        # Save checkpoint
        print()
        if FLAGS.saving_checkpoint:
            save_model(model, optimizer, epoch, FLAGS.train_width_mults,
                       FLAGS.rand_width_mult_args, train_meters, val_meters,
                       1 - avg_error1, 1 - best_val_error1,
                       FLAGS.epoch_checkpoint, is_best, FLAGS.best_checkpoint)
        print('==> Epoch avg accuracy {:.2f}%,'.format((1 - avg_error1) * 100),
              'Best accuracy: {:.2f}%\n'.format((1 - best_val_error1) * 100))

        logger.flush()

        if lr_scheduler is not None and epoch != FLAGS.num_epochs - 1:
            lr_scheduler.step()
        print('Epoch time: {:.4f} mins'.format(
            (time.time() - t_epoch_start) / 60))

    print('Total time: {:.4f} mins'.format((time.time() - t_exp_start) / 60))
    logger.close()
    return