def main():
    args = get_args()

    # Set up a logger that writes both to stdout and to a per-run file under ./logs.
    check_mkdir('./logs')
    logging_name = './logs/{}_{}_lr_{}.txt'.format(args.networks, args.optimizer, args.lr)

    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)10s][%(levelname)s] %(message)s',
                                  datefmt='%Y/%m/%d %H:%M:%S')
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    file_handler = logging.FileHandler(logging_name)
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)

    logger.info('arguments: {}'.format(' '.join(sys.argv)))

    if args.ignite is False:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Build the model, wrapping it in DataParallel when several GPUs are available.
        model = mobile_hair.MobileMattingFCN()
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                print('multi gpu')
                model = torch.nn.DataParallel(model)
        model.to(device)

        loss = mobile_hair.HairMattingLoss()
        optimizer = get_optimizer(args.optimizer, model, args.lr, args.momentum)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

        train_without_ignite(model,
                             loss,
                             batch_size=args.batch_size,
                             img_size=args.img_size,
                             epochs=args.epochs,
                             lr=args.lr,
                             num_workers=args.num_workers,
                             optimizer=optimizer,
                             logger=logger,
                             gray_image=True,
                             scheduler=scheduler,
                             viz=args.visdom)
    else:
        train_with_ignite(networks=args.networks,
                          dataset=args.dataset,
                          data_dir=args.data_dir,
                          batch_size=args.batch_size,
                          epochs=args.epochs,
                          lr=args.lr,
                          num_workers=args.num_workers,
                          optimizer=args.optimizer,
                          momentum=args.momentum,
                          img_size=args.img_size,
                          logger=logger)
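
# --- Hypothetical sketch, not part of the original file ---
# main() above calls get_optimizer(args.optimizer, model, args.lr, args.momentum);
# the real helper is defined elsewhere in this repo. A minimal implementation
# compatible with that call signature might look like the following (optimizer
# names and defaults are assumptions, not the repo's actual choices):
def _get_optimizer_sketch(name, model, lr, momentum):
    """Map an optimizer name to a torch.optim instance (illustrative only)."""
    params = filter(lambda p: p.requires_grad, model.parameters())
    if name.lower() == 'adam':
        return torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))
    if name.lower() == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=momentum)
    raise ValueError('unknown optimizer: {}'.format(name))
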
ckpt_dir = args.ckpt_dir
img_dir = args.img_dir
network = args.networks.lower()
save_dir = args.save_dir
device = 'cuda' if args.use_gpu else 'cpu'

# Sanity-check the input paths before doing any work.
assert os.path.exists(ckpt_dir)
assert os.path.exists(img_dir)
assert os.path.exists(os.path.split(save_dir)[0])
os.makedirs(save_dir, exist_ok=True)

# Prepare the network with trained parameters. `net` is built from the --networks
# argument, while the checkpoint weights are loaded into a MobileMattingFCN instance.
net = get_network(network).to(device)
model = mobile_hair.MobileMattingFCN()
model.load_state_dict(torch.load(ckpt_dir, map_location=device))
model = model.to(device)
model.eval()  # inference mode
print(model)

test_image_transforms = std_trnsf.Compose([
    std_trnsf.ToTensor(),
    std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

durations = list()
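
# --- Hypothetical sketch, not part of the original file ---
# The test code above builds `test_image_transforms` and an empty `durations`
# list; a typical per-image inference loop that fills them might look like this.
# The file listing, the sigmoid on the raw output, and the timing report are
# assumptions for illustration, not the repo's actual evaluation code.
import time
from PIL import Image

with torch.no_grad():
    for fname in sorted(os.listdir(img_dir)):
        img = Image.open(os.path.join(img_dir, fname)).convert('RGB')
        data = test_image_transforms(img).unsqueeze(0).to(device)

        start = time.time()
        pred = torch.sigmoid(model(data))  # hair probability map in [0, 1]
        durations.append(time.time() - start)
        # the real script presumably writes the predicted mask to save_dir here

print('avg inference time: {:.4f}s'.format(sum(durations) / max(len(durations), 1)))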