def init_logger(config, args):
    """Create the output directories, the run logger and the summary writer.

    Both ``config.summary_dir`` and ``config.checkpoint_dir`` are created via
    ``makedirs``. The logger is obtained from ``get_logger`` and the
    configuration is logged once as a dict. The ``SummaryWriter`` is rooted
    at a sub-path of ``config.summary_dir`` whose name encodes the run's
    settings::

        <init>_sp<target_ratio>[_scaled][_bn][_<sigma_w2>]_<act>_<depth>

    with every ``.`` in ``target_ratio`` / ``sigma_w2`` replaced by ``_``.

    Args:
        config: Run configuration. Reads ``summary_dir``, ``checkpoint_dir``,
            ``network``, ``pruner_file`` and ``depth``; must be convertible
            with ``dict(config)``.
        args: Parsed command-line arguments. Reads ``init``, ``target_ratio``,
            ``scaled_init``, ``bn``, ``sigma_w2`` and ``act``.

    Returns:
        tuple: ``(logger, writer)``.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    # Source files handed to get_logger together with the log location.
    path = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(path, 'models/base/%s.py' % config.network.lower())
    path_main = os.path.join(path, 'main_prune_non_imagenet.py')
    path_pruner = os.path.join(path, 'pruner/%s.py' % config.pruner_file)
    logger = get_logger(
        'log',
        logpath=config.summary_dir + '/',
        filepath=path_model,
        package_files=[path_main, path_pruner],
    )
    logger.info(dict(config))

    # Build a per-run writer directory name out of the experiment settings.
    summary_writer_path = (
        config.summary_dir + '/' + args.init
        + '_sp' + str(args.target_ratio).replace('.', '_')
    )
    if args.scaled_init:
        summary_writer_path += '_scaled'
    if args.bn:
        summary_writer_path += '_bn'
    # Identity test against None: `!= None` goes through __ne__, which custom
    # or array-like values may overload.
    if args.sigma_w2 is not None and args.init == 'ordered':
        summary_writer_path += '_{}'.format(args.sigma_w2).replace('.', '_')
    summary_writer_path += '_' + args.act + '_' + str(config.depth)

    writer = SummaryWriter(summary_writer_path)
    return logger, writer
def init_logger(config):
    """Prepare output directories, the run logger and the summary writer.

    Creates ``config.summary_dir`` and ``config.checkpoint_dir`` via
    ``makedirs``, obtains a logger from ``get_logger('log', ...)`` with
    ``logpath`` set to the summary directory, and logs the configuration
    once as a dict.

    Args:
        config: Run configuration. Reads ``summary_dir``, ``checkpoint_dir``
            and ``pruner_file``; must be convertible with ``dict(config)``.

    Returns:
        tuple: ``(logger, writer)`` where ``writer`` is a ``SummaryWriter``
        constructed on ``config.summary_dir``.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    here = os.path.dirname(os.path.abspath(__file__))
    # The model source is fixed to the VGG definition in this variant.
    model_src = os.path.join(here, 'models/base/%s.py' % 'vgg')
    packaged = [
        os.path.join(here, 'main_prune_imagenet.py'),
        os.path.join(here, 'pruner/%s.py' % config.pruner_file),
    ]

    logger = get_logger(
        'log',
        logpath=config.summary_dir+'/',
        filepath=model_src,
        package_files=packaged,
    )
    logger.info(dict(config))

    return logger, SummaryWriter(config.summary_dir)
def init_logger(config):
    """Set up the run's directories, logger and summary writer.

    ``config.summary_dir`` and ``config.checkpoint_dir`` are created with
    ``makedirs``. A logger named ``'log'`` is requested from ``get_logger``
    and the configuration is logged once as a dict.

    Args:
        config: Run configuration. Reads ``summary_dir``, ``checkpoint_dir``,
            ``network`` and ``pruner_file``; must be convertible with
            ``dict(config)``.

    Returns:
        tuple: ``(logger, writer)`` where ``writer`` is a ``SummaryWriter``
        constructed on ``config.summary_dir``.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    base_dir = os.path.dirname(os.path.abspath(__file__))

    def _beside_this_file(relative):
        # Resolve a path relative to the directory containing this module.
        return os.path.join(base_dir, relative)

    network_source = _beside_this_file('models/base/%s.py' % config.network.lower())
    entry_source = _beside_this_file('main_prune_non_imagenet.py')
    pruner_source = _beside_this_file('pruner/%s.py' % config.pruner_file)

    logger = get_logger(
        'log',
        logpath=config.summary_dir + '/',
        filepath=network_source,
        package_files=[entry_source, pruner_source],
    )
    logger.info(dict(config))

    writer = SummaryWriter(config.summary_dir)
    return logger, writer
def init_summary_writer(config):
    """Create the run's directories, logger and summary writer.

    Creates ``config.summary_dir`` and ``config.checkpoint_dir`` via
    ``makedirs``, prints ``config.checkpoint`` together with whether it
    already exists, and creates it as a directory when it does not. A logger
    is then requested from ``get_logger`` and the configuration is logged
    once as a dict.

    The logger name embeds ``running_time``, a module-level name defined
    outside this function.

    Args:
        config: Run configuration. Reads ``summary_dir``, ``checkpoint_dir``,
            ``checkpoint``, ``network``, ``pruner`` and ``saving_log``; must
            be convertible with ``dict(config)``.

    Returns:
        tuple: ``(logger, writer)`` where ``writer`` is a ``SummaryWriter``
        constructed on ``config.summary_dir``.
    """
    makedirs(config.summary_dir)
    makedirs(config.checkpoint_dir)

    print(config.checkpoint, os.path.exists(config.checkpoint))
    if not os.path.exists(config.checkpoint):
        # exist_ok=True closes the window between the check above and the
        # creation: a concurrent run creating the same directory no longer
        # makes this call raise FileExistsError. The guard is kept so that an
        # existing non-directory path is still left untouched.
        os.makedirs(config.checkpoint, exist_ok=True)

    # Source files handed to get_logger together with the log location.
    path = os.path.dirname(os.path.abspath(__file__))
    path_model = os.path.join(path, 'models/%s.py' % config.network)
    path_main = os.path.join(path, 'main_prune.py')
    path_pruner = os.path.join(path, 'pruner/%s.py' % config.pruner)
    logger = get_logger(
        f'log{running_time}.log_time',
        logpath=config.saving_log,
        filepath=path_model,
        package_files=[path_main, path_pruner],
    )
    logger.info(dict(config))

    writer = SummaryWriter(config.summary_dir)
    return logger, writer
# NOTE(review): this span opens with the tail of a function whose header lies
# outside the visible chunk. That tail stores `stat` under `stats[it]`; when
# prune_mode is 'one_pass' it deletes and rebuilds the network and pruner
# (restoring `pruner.iter = it`); and it dumps `stats` as JSON to
# <config.summary_dir>/stats.json.
#
# Script entry point: parses --tmp_config and --config (both optional strings
# defaulting to ''). When --tmp_config is longer than one character, the
# config is loaded with get_config_from_json, summary_dir is created, and
# sys.stdout / sys.stderr are rebound to stdout.txt / stderr.txt inside it
# (opened with 'w+' and not closed in the visible code) before main(config)
# runs. Otherwise the config comes from process_config(args.config) and
# main(config) runs with the normal streams.
stats[it] = stat if prune_mode == 'one_pass': del net del pruner net, bottleneck_net = init_network(config, logger, device) pruner = init_pruner(net, bottleneck_net, config, writer, logger) pruner.iter = it with open(os.path.join(config.summary_dir, 'stats.json'), 'w') as f: json.dump(stats, f) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--tmp_config', type=str, default='', required=False) parser.add_argument('--config', type=str, default='', required=False) args = parser.parse_args() if len(args.tmp_config) > 1: print('Using tmp config!') config, _ = get_config_from_json(args.tmp_config) makedirs(config.summary_dir) sys.stdout = open(os.path.join(config.summary_dir, 'stdout.txt'), 'w+') sys.stderr = open(os.path.join(config.summary_dir, 'stderr.txt'), 'w+') main(config) else: print('Using config!') config = process_config(args.config) main(config)
# Optional resume: when args.resume is set, asserts that 'checkpoint/pretrain'
# is a directory, loads
# 'checkpoint/pretrain/<dataset>_<network><depth>_bn_best.t7' with torch.load,
# restores its 'net' entry into `net`, and takes best_acc / start_epoch from
# its 'acc' / 'epoch' entries.
#
# Summary writer: creates <args.log_dir>/<dataset>_<network><depth> via
# makedirs and opens a SummaryWriter on it.
#
# train(epoch): only the opening of this function is visible in this chunk.
# The visible part prints the epoch, switches `net` to train mode, zeroes the
# loss / correct / total counters, calls lr_scheduler(optimizer, epoch),
# formats a progress description string, and records
# lr_scheduler.get_lr(optimizer) under the 'train/lr' scalar tag.
if args.resume: print('==> Resuming from checkpoint..') assert os.path.isdir( 'checkpoint/pretrain'), 'Error: no checkpoint directory found!' checkpoint = torch.load('checkpoint/pretrain/%s_%s%s_bn_best.t7' % (args.dataset, args.network, args.depth)) net.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' % (start_epoch, best_acc)) # init summary writter log_dir = os.path.join(args.log_dir, '%s_%s%s' % (args.dataset, args.network, args.depth)) makedirs(log_dir) writer = SummaryWriter(log_dir) def train(epoch): print('\nEpoch: %d' % epoch) net.train() train_loss = 0 correct = 0 total = 0 lr_scheduler(optimizer, epoch) desc = ('[LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (lr_scheduler.get_lr(optimizer), 0, 0, correct, total)) writer.add_scalar('train/lr', lr_scheduler.get_lr(optimizer), epoch)