Example no. 1
0
def init_logger(config, args):
    """Create the experiment logger and TensorBoard writer for a non-ImageNet pruning run.

    Args:
        config: run configuration; must provide ``summary_dir``, ``checkpoint_dir``,
            ``network``, ``pruner_file`` and ``depth``.
        args: parsed CLI arguments; ``init``, ``target_ratio``, ``scaled_init``,
            ``bn``, ``sigma_w2`` and ``act`` are encoded into the writer path.

    Returns:
        (logger, writer) tuple.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    # Snapshot the model / main / pruner source files next to the log so the
    # run is reproducible from the log directory alone.
    path = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(path,
                              'models/base/%s.py' % config.network.lower())
    path_main = os.path.join(path, 'main_prune_non_imagenet.py')
    path_pruner = os.path.join(path, 'pruner/%s.py' % config.pruner_file)
    logger = get_logger('log',
                        logpath=config.summary_dir + '/',
                        filepath=path_model,
                        package_files=[path_main, path_pruner])
    logger.info(dict(config))

    # Encode the run's hyper-parameters into the summary directory name
    # ('.' is replaced by '_' so the ratio is filesystem-safe).
    summary_writer_path = config.summary_dir + '/' + args.init + '_sp' + str(
        args.target_ratio).replace('.', '_')
    if args.scaled_init:
        summary_writer_path += '_scaled'
    if args.bn:
        summary_writer_path += '_bn'

    # Fix: compare against None with `is not None` instead of `!= None` (PEP 8).
    if args.sigma_w2 is not None and args.init == 'ordered':
        summary_writer_path += '_{}'.format(args.sigma_w2).replace('.', '_')
    summary_writer_path += '_' + args.act + '_' + str(config.depth)

    writer = SummaryWriter(summary_writer_path)
    return logger, writer
def init_logger(config):
    """Set up the experiment logger and TensorBoard writer for an ImageNet pruning run.

    Returns a ``(logger, writer)`` tuple; the logger snapshots the VGG model,
    main script and pruner source files alongside the log.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    # Source files recorded together with the log output.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(base_dir, 'models/base/%s.py' % 'vgg')
    path_main = os.path.join(base_dir, 'main_prune_imagenet.py')
    path_pruner = os.path.join(base_dir, 'pruner/%s.py' % config.pruner_file)

    logger = get_logger(
        'log',
        logpath=config.summary_dir + '/',
        filepath=path_model,
        package_files=[path_main, path_pruner],
    )
    logger.info(dict(config))

    writer = SummaryWriter(config.summary_dir)
    return logger, writer
Example no. 3
0
def init_logger(config):
    """Build the run logger and a TensorBoard writer rooted at ``config.summary_dir``.

    The network, main-script and pruner source files are attached to the log
    so the exact code used for the run is preserved.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    here = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(here, 'models/base/%s.py' % config.network.lower())
    path_main = os.path.join(here, 'main_prune_non_imagenet.py')
    path_pruner = os.path.join(here, 'pruner/%s.py' % config.pruner_file)

    logger = get_logger(
        'log',
        logpath=config.summary_dir + '/',
        filepath=path_model,
        package_files=[path_main, path_pruner],
    )
    logger.info(dict(config))

    writer = SummaryWriter(config.summary_dir)
    return logger, writer
Example no. 4
0
def init_summary_writer(config):
    """Create the run logger and TensorBoard writer for a pruning experiment.

    Returns:
        (logger, writer) tuple.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)
    # NOTE(review): debug print left in; config.checkpoint is created separately
    # below instead of via makedirs like the two directories above — confirm
    # whether that asymmetry is intentional.
    print(config.checkpoint, os.path.exists(config.checkpoint))
    if not os.path.exists(config.checkpoint):
        os.makedirs(config.checkpoint)

    # Source files snapshotted alongside the log for reproducibility.
    path = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(path, 'models/%s.py' % config.network)
    path_main = os.path.join(path, 'main_prune.py')
    path_pruner = os.path.join(path, 'pruner/%s.py' % config.pruner)

    # NOTE(review): `running_time` is not defined in this snippet — presumably a
    # module-level timestamp set at import time; verify it exists in the full file.
    logger = get_logger(f'log{running_time}.log_time',
                        logpath=config.saving_log,
                        filepath=path_model,
                        package_files=[path_main, path_pruner])
    logger.info(dict(config))
    writer = SummaryWriter(config.summary_dir)

    return logger, writer
        stats[it] = stat

        if prune_mode == 'one_pass':
            del net
            del pruner
            net, bottleneck_net = init_network(config, logger, device)
            pruner = init_pruner(net, bottleneck_net, config, writer, logger)
            pruner.iter = it
        with open(os.path.join(config.summary_dir, 'stats.json'), 'w') as f:
            json.dump(stats, f)


if __name__ == '__main__':
    # Entry point: load either a temporary JSON config or a regular processed
    # config, then run the experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument('--tmp_config', type=str, default='', required=False)
    parser.add_argument('--config', type=str, default='', required=False)
    args = parser.parse_args()

    # Fix: test truthiness instead of `len(args.tmp_config) > 1`, which
    # silently ignored a one-character path.
    if args.tmp_config:
        print('Using tmp config!')
        config, _ = get_config_from_json(args.tmp_config)
        makedirs(config.summary_dir)
        # Redirect stdout/stderr into the run directory; the files stay open
        # for the lifetime of the process (closed at interpreter exit).
        sys.stdout = open(os.path.join(config.summary_dir, 'stdout.txt'), 'w+')
        sys.stderr = open(os.path.join(config.summary_dir, 'stderr.txt'), 'w+')
        main(config)
    else:
        print('Using config!')
        config = process_config(args.config)
        main(config)
# Optionally resume training state from a pretrained checkpoint.
if args.resume:
    print('==> Resuming from checkpoint..')
    assert os.path.isdir(
        'checkpoint/pretrain'), 'Error: no checkpoint directory found!'
    # Checkpoint filename encodes dataset, network and depth
    # (e.g. cifar10_vgg19_bn_best.t7).
    checkpoint = torch.load('checkpoint/pretrain/%s_%s%s_bn_best.t7' %
                            (args.dataset, args.network, args.depth))
    net.load_state_dict(checkpoint['net'])
    # Restore the best accuracy and starting epoch so training continues
    # from where the checkpoint left off.
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' %
          (start_epoch, best_acc))

# init summary writer: one TensorBoard log directory per
# dataset/network/depth combination.
log_dir = os.path.join(args.log_dir,
                       '%s_%s%s' % (args.dataset, args.network, args.depth))
makedirs(log_dir)
writer = SummaryWriter(log_dir)


def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    lr_scheduler(optimizer, epoch)
    desc = ('[LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))

    writer.add_scalar('train/lr', lr_scheduler.get_lr(optimizer), epoch)