def main():
    """Script entry point: set up logging, data and model, then run training.

    Steps, in order:
      1. Parse command-line arguments and configure the 'global' logger
         (plus an optional log file given by ``args.log``).
      2. Load the configuration with ``load_config`` and log it as JSON.
      3. Create the TensorBoard writer (a ``Dummy`` stand-in when
         ``args.log_dir`` is empty).
      4. Build the data loaders and the ``Custom`` model, optionally loading
         pretrained weights and resuming from a checkpoint.
      5. Call ``train`` with the DataParallel-wrapped model.

    Sets the module-level globals ``args``, ``tb_writer`` and ``logger``;
    ``best_acc`` is assigned only when resuming from a checkpoint.

    Raises:
        SystemExit: if ``args.arch`` is not ``'Custom'`` (non-zero status).
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    # Initialise the 'global' logger; logging.INFO is the log level.
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    # Fetch the logger object that was initialised above.
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    # load_config returns the (modified) configuration object.
    cfg = load_config(args)
    # json.dumps turns the config into a str for logging.
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset (the validation loader is not used by this entry point)
    train_loader, _ = build_data_loader(cfg)

    # Only the 'Custom' architecture is available. Fail loudly with a
    # non-zero exit status instead of silently exiting with status 0.
    if args.arch != 'Custom':
        raise SystemExit(
            "Unsupported architecture {!r}: only 'Custom' is available".format(
                args.arch))
    from custom import Custom
    model = Custom(pretrain=True, anchors=cfg['anchors'])
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    device_ids = list(range(torch.cuda.device_count()))
    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, device_ids).cuda()

    if args.resume and args.start_epoch != 0:
        # The argument is the fraction of training completed before the
        # start epoch.
        # NOTE(review): presumably this unfreezes backbone layers - confirm.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)

    # optionally resume from a checkpoint
    if args.resume:
        # Explicit check rather than `assert`, which is stripped under -O.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                '{} is not a valid file'.format(args.resume))
        model, optimizer, args.start_epoch, best_acc, _ = restore_from(
            model, optimizer, args.resume)
        # restore_from returns a model, so the parallel wrapper is rebuilt.
        dist_model = torch.nn.DataParallel(model, device_ids).cuda()

    logger.info(lr_scheduler)
    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler,
          args.start_epoch, cfg)
def main():
    """Script entry point: build the model from pretrained weights and validate.

    Parses and post-processes the command-line arguments, configures the
    'global' logger, loads the configuration, builds the data loaders and the
    ``Custom`` model, loads the mandatory pretrained weights and finally runs
    ``validation`` on the validation loader.

    Sets the module-level globals ``args`` and ``logger``.

    Raises:
        SystemExit: if ``args.arch`` is not ``'Custom'`` (non-zero status).
        Exception: if ``args.pretrained`` is empty.
    """
    global args, logger
    args = args_process(parser.parse_args())

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    # Copy dataset settings from the config onto args, which is handed to
    # the model as ``opts``.
    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])

    # Only the 'Custom' architecture is available. Fail loudly with a
    # non-zero exit status instead of silently exiting with status 0.
    if args.arch != 'Custom':
        raise SystemExit(
            "Unsupported architecture {!r}: only 'Custom' is available".format(
                args.arch))
    from custom import Custom
    # Anchors come from the training dataset object here, not from cfg.
    model = Custom(pretrain=True, opts=args,
                   anchors=train_loader.dataset.anchors)
    logger.info(model)

    # Validation requires pretrained weights.
    if not args.pretrained:
        raise Exception("Pretrained weights must be loaded!")
    model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(
        model, list(range(torch.cuda.device_count()))).cuda()

    logger.info('model prepare done')

    val_avg = AverageMeter()
    validation(val_loader, dist_model, cfg, val_avg)
def main():
    """Script entry point: set up logging, data and model, then run training.

    Parses and post-processes the command-line arguments, configures the
    'global' logger, loads the configuration, creates the TensorBoard writer
    (a ``Dummy`` stand-in when ``args.log_dir`` is empty), builds the data
    loaders and the ``Custom`` model, optionally loads pretrained weights and
    resumes from a checkpoint, then calls ``train`` with the
    DataParallel-wrapped model.

    Sets the module-level globals ``args``, ``tb_writer`` and ``logger``;
    ``best_acc`` is assigned only when resuming from a checkpoint.

    Raises:
        SystemExit: if ``args.arch`` is not ``'Custom'`` (non-zero status).
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    global args, best_acc, tb_writer, logger
    args = args_process(parser.parse_args())

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset (the validation loader is not used by this entry point)
    train_loader, _ = build_data_loader(cfg)

    # Copy dataset settings from the config onto args, which is handed to
    # the model as ``opts``.
    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])

    # Only the 'Custom' architecture is available. Fail loudly with a
    # non-zero exit status instead of silently exiting with status 0.
    if args.arch != 'Custom':
        raise SystemExit(
            "Unsupported architecture {!r}: only 'Custom' is available".format(
                args.arch))
    from custom import Custom
    model = Custom(pretrain=True, opts=args, anchors=cfg['anchors'])
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    device_ids = list(range(torch.cuda.device_count()))
    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, device_ids).cuda()

    if args.resume and args.start_epoch != 0:
        # The argument is the fraction of training completed before the
        # start epoch.
        # NOTE(review): presumably this unfreezes backbone layers - confirm.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)

    # optionally resume from a checkpoint
    if args.resume:
        # Explicit check rather than `assert`, which is stripped under -O.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                '{} is not a valid file'.format(args.resume))
        model, optimizer, args.start_epoch, best_acc, _ = restore_from(
            model, optimizer, args.resume)
        # restore_from returns a model, so the parallel wrapper is rebuilt.
        dist_model = torch.nn.DataParallel(model, device_ids).cuda()

    logger.info(lr_scheduler)
    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler,
          args.start_epoch, cfg)
def main():
    """Script entry point: run ``valid`` on the checkpoint of every epoch 1..20.

    Parses the command-line arguments, configures the 'global' logger, loads
    the configuration, creates the TensorBoard writer and builds the data
    loaders. For each epoch it then builds a fresh ``Custom`` model,
    optionally loads pretrained weights, restores that epoch's checkpoint
    from a hard-coded snapshot path, wraps the model in DataParallel and
    calls ``valid`` on the validation loader.

    Progress markers are printed to stdout along the way.

    Sets the module-level globals ``args``, ``tb_writer`` and ``logger``.

    Raises:
        SystemExit: if ``args.arch`` is not ``'Custom'`` (non-zero status).
        FileNotFoundError: if an epoch's checkpoint file does not exist.
    """
    global args, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    print("Init logger")

    logger = logging.getLogger('global')
    print(44)
    # NOTE: logging of collect_env_info() is disabled in this variant.
    print(99)
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    print(2)

    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset (only the validation loader is used below)
    _, val_loader = build_data_loader(cfg)
    print(3)

    # NOTE(review): hard-coded, machine-specific checkpoint location.
    path = "/usr4/alg504/cliao25/siammask/experiments/siammask_base/snapshot/checkpoint_e{}.pth"

    # The architecture check does not depend on the epoch, so do it once.
    # Fail loudly with a non-zero exit status instead of silently exiting
    # with status 0.
    if args.arch != 'Custom':
        raise SystemExit(
            "Unsupported architecture {!r}: only 'Custom' is available".format(
                args.arch))
    from custom import Custom

    device_ids = list(range(torch.cuda.device_count()))
    for epoch in range(1, 21):
        model = Custom(pretrain=True, anchors=cfg['anchors'])
        print(4)

        if args.pretrained:
            model = load_pretrain(model, args.pretrained)
        model = model.cuda()

        # NOTE: model.features.unfix(...) is not called in this variant.
        # The optimizer is built only because restore_from takes one; the
        # scheduler is not used.
        optimizer, _ = build_opt_lr(model, cfg, args, epoch)

        filepath = path.format(epoch)
        # Explicit check rather than a bare `assert`: it survives -O and
        # names the missing file.
        if not os.path.isfile(filepath):
            raise FileNotFoundError(
                '{} is not a valid file'.format(filepath))
        model, _, _, _, _ = restore_from(model, optimizer, filepath)

        model = torch.nn.DataParallel(model, device_ids).cuda()
        # NOTE(review): the model is put in training mode before valid();
        # confirm this is intended.
        model.train()
        device = torch.device('cuda')
        model = model.to(device)

        valid(val_loader, model, cfg)

    print("Done")