def main():
    import argparse
    parser = argparse.ArgumentParser(
        description="imsitu Situation CRF. Training, evaluation, prediction and features.")
    parser.add_argument("--command",
                        choices=["train", "eval", "predict", "features"],
                        required=True)
    parser.add_argument(
        "--output_dir",
        help="location to put output, such as models, features, predictions")
    parser.add_argument("--image_dir",
                        default="./resized_256",
                        help="location of images to process")
    parser.add_argument("--dataset_dir",
                        default="./",
                        help="location of train.json, dev.json, etc.")
    parser.add_argument(
        "--dataset_postfix",
        default="_v1",
        help="dataset postfix for different versions, e.g. _v1")
    parser.add_argument("--weights_file", help="the model to start from")
    parser.add_argument("--encoding_file",
                        help="a file corresponding to the encoder")
    parser.add_argument("--cnn_type",
                        choices=["resnet_34", "resnet_50", "resnet_101"],
                        default="resnet_101",
                        help="the CNN to initialize the CRF with")
    parser.add_argument("--batch_size",
                        default=64,
                        help="batch size for training",
                        type=int)
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        help="learning rate for Adam",
                        type=float)
    parser.add_argument("--weight_decay",
                        default=5e-4,
                        help="weight decay for Adam",
                        type=float)
    parser.add_argument("--eval_frequency",
                        default=100,
                        help="evaluate on dev set every N training steps",
                        type=int)
    parser.add_argument("--training_epochs",
                        default=20,
                        help="total number of training epochs",
                        type=int)
    parser.add_argument(
        "--eval_file",
        default="dev.json",
        help="the dataset file to evaluate on, e.g. dev.json or test.json")
    parser.add_argument("--top_k",
                        default=10,
                        type=int,
                        help="top-k to use for writing predictions to file")
    parser.add_argument("--device_array", nargs='+', type=int, default=[0])
    parser.add_argument("--use_wandb", action='store_true')
    args = parser.parse_args()

    if args.command == "train":
        print("command = training")
        train_set = json.load(
            open(args.dataset_dir + "/train" + args.dataset_postfix + ".json"))
        dev_set = json.load(
            open(args.dataset_dir + "/dev" + args.dataset_postfix + ".json"))

        # Build the verb/role/noun encoder from the training set unless one was provided.
        if args.encoding_file is None:
            encoder = imSituVerbRoleLocalNounEncoder(train_set)
            torch.save(encoder, args.output_dir + "/encoder")
        else:
            encoder = torch.load(args.encoding_file)

        model = BaselineCRF(encoder,
                            cnn_type=args.cnn_type,
                            device_array=args.device_array)
        if args.weights_file is not None:
            model.load_state_dict(torch.load(args.weights_file))

        dataset_train = imSituSituation(args.image_dir, train_set, encoder,
                                        model.train_preprocess())
        dataset_dev = imSituSituation(args.image_dir, dev_set, encoder,
                                      model.dev_preprocess())

        # Scale the per-device batch size by the number of devices.
        batch_size = args.batch_size * len(args.device_array)
        train_loader = torch.utils.data.DataLoader(
            dataset_train, batch_size=batch_size,
            shuffle=True)  # , num_workers=3)
        dev_loader = torch.utils.data.DataLoader(
            dataset_dev, batch_size=batch_size,
            shuffle=True)  # , num_workers=3)

        model.cuda(args.device_array[0])
        optimizer = optim.Adam(model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)
        train_model(args.training_epochs, args.eval_frequency, train_loader,
                    dev_loader, model, encoder, optimizer, args.output_dir,
                    args.device_array, args)

    elif args.command == "eval":
        print("command = evaluating")
        eval_file = json.load(open(args.dataset_dir + "/" + args.eval_file))

        if args.encoding_file is None:
            print("expecting encoder file to run evaluation")
            exit()
        else:
            encoder = torch.load(args.encoding_file)

        print("creating model...")
        model = BaselineCRF(encoder, cnn_type=args.cnn_type)

        if args.weights_file is None:
print("expecting weight file to run features") exit() print("loading model weights...") model.load_state_dict(torch.load(args.weights_file)) model.cuda() dataset = imSituSituation(args.image_dir, eval_file, encoder, model.dev_preprocess()) loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=3) (top1, top5) = eval_model(loader, encoder, model) top1_a = top1.get_average_results() top5_a = top5.get_average_results() avg_score = top1_a["verb"] + top1_a["value"] + top1_a["value-all"] + top5_a["verb"] + \ top5_a["value"] + top5_a["value-all"] + \ top5_a["value*"] + top5_a["value-all*"] avg_score /= 8 print("Average :{:.2f} {} {}".format( avg_score * 100, format_dict(top1_a, "{:.2f}", "1-"), format_dict(top5_a, "{:.2f}", "5-"))) elif args.command == "features": print("command = features") if args.encoding_file is None: print("expecting encoder file to run features") exit() else: encoder = torch.load(args.encoding_file) print("creating model...") model = BaselineCRF(encoder, cnn_type=args.cnn_type) if args.weights_file is None: print("expecting weight file to run features") exit() print("loading model weights...") model.load_state_dict(torch.load(args.weights_file)) model.cuda() folder_dataset = imSituSimpleImageFolder(args.image_dir, model.dev_preprocess()) image_loader = torch.utils.data.DataLoader(folder_dataset, batch_size=args.batch_size, shuffle=False, num_workers=3) compute_features(image_loader, folder_dataset, model, args.output_dir) elif args.command == "predict": print("command = predict") if args.encoding_file is None: print("expecting encoder file to run features") exit() else: encoder = torch.load(args.encoding_file) print("creating model...") model = BaselineCRF(encoder, cnn_type=args.cnn_type) if args.weights_file is None: print("expecting weight file to run features") exit() print("loading model weights...") model.load_state_dict(torch.load(args.weights_file)) model.cuda() folder_dataset = imSituSimpleImageFolder(args.image_dir, model.dev_preprocess()) image_loader = torch.utils.data.DataLoader(folder_dataset, batch_size=args.batch_size, shuffle=False, num_workers=3) predict_human_readable(image_loader, folder_dataset, encoder, model, args.output_dir, args.top_k)