예제 #1
0
def main():
    """End-to-end training pipeline for the siamese network.

    Extracts the raw archive, splits it into train/test, builds the
    pair-sampling data loaders, fine-tunes a MobileNetV2-based twin
    network, and saves the trained twin weights to models/twin.pt.
    """
    import os  # local import: needed only for the save-directory fix below

    print("Extract data")
    unzip_data()

    print("Split on train and test")
    split_on_train_and_test()

    print("Create datasets")
    train_ds, test_ds = prepare_datasets()

    print("Create data loaders")
    # Samplers draw image pairs for the siamese objective; RS fixes the seed.
    train_sampler = SiameseSampler(train_ds, random_state=RS)
    test_sampler = SiameseSampler(test_ds, random_state=RS)
    train_data_loader = DataLoader(train_ds,
                                   batch_size=BATCH_SIZE,
                                   sampler=train_sampler,
                                   num_workers=4)
    test_data_loader = DataLoader(test_ds,
                                  batch_size=BATCH_SIZE,
                                  sampler=test_sampler,
                                  num_workers=4)

    print("Build computational graph")
    mobilenet = mobilenet_v2(pretrained=True)
    # Drop the classifier head; keep only the convolutional feature extractor.
    mobilenet = torch.nn.Sequential(*(list(mobilenet.children())[:-1]))
    siams = SiameseNetwork(twin_net=TransferTwinNetwork(
        base_model=mobilenet, output_dim=EMBEDDING_DIM))
    siams.to(DEVICE)
    # BCEWithLogitsLoss: the network emits raw logits for the same/different label.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(siams.parameters(), lr=LR)

    print("Train model")
    siams = train(siams, criterion, optimizer, train_data_loader,
                  test_data_loader)

    print("Save model")
    # Fix: torch.save raises FileNotFoundError if 'models/' does not exist yet.
    os.makedirs('models', exist_ok=True)
    torch.save(siams.twin_net.state_dict(), 'models/twin.pt')
예제 #2
0
    # CLI: checkpoint to load and destination path for the exported model.
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        help="Path of model checkpoint to be used for inference.",
        required=True
    )
    parser.add_argument(
        '-o',
        '--out_path',
        type=str,
        help="Path for saving tensorrt model.",
        required=True
    )

    args = parser.parse_args()

    # Prefer GPU when available; dummy inputs below are moved to this device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # NOTE(review): no map_location here — loading a GPU-saved checkpoint on
    # a CPU-only machine will fail; consider torch.load(..., map_location=device).
    checkpoint = torch.load(args.checkpoint)
    model = SiameseNetwork(backbone=checkpoint['backbone'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout, uses running BN stats

    # Export with a pair of dummy 1x3x224x224 inputs (the siamese forward takes
    # two images). NOTE(review): input_names lists one name for two inputs —
    # the second ONNX input gets an auto-generated name; confirm downstream
    # consumers do not depend on it.
    torch.onnx.export(model, (torch.rand(1, 3, 224, 224).to(device), torch.rand(1, 3, 224, 224).to(device)), args.out_path, input_names=['input'],
                      output_names=['output'], export_params=True)
    
    # Sanity-check that the exported graph is a valid ONNX model.
    onnx_model = onnx.load(args.out_path)
    onnx.checker.check_model(onnx_model)
예제 #3
0
# Assemble the siamese model from its branch sub-network.
branch = BranchNetwork()
net = SiameseNetwork(branch)

# Abort early when the weights file is missing (guard clause).
if not os.path.isfile(model_name):
    print('Error: file not found at {}'.format(model_name))
    sys.exit()
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
net.load_state_dict(checkpoint['state_dict'])
print('load model file from {}.'.format(model_name))

# 3: setup computation device
use_gpu = torch.cuda.is_available()
device = torch.device('cuda:{}'.format(cuda_id)) if use_gpu else 'cpu'
if use_gpu:
    net = net.to(device)
    cudnn.benchmark = True
print('computation device: {}'.format(device))

# Extract an embedding batch for every entry in the loader.
features = []
with torch.no_grad():  # inference only; skip autograd bookkeeping
    for i in range(len(data_loader)):
        x, _ = data_loader[i]
        feat = net.feature_numpy(x.to(device))  # N x C

        # append to the feature list
        features.append(feat)

        # periodic progress report
        if i % 100 == 0:
            print('finished {} in {}'.format(i + 1, len(data_loader)))
예제 #4
0
# Optionally resume training from a saved checkpoint.
if resume and os.path.isfile(resume):
    # Load on CPU first (map to local storage); moved to device later.
    checkpoint = torch.load(resume,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['state_dict'])
    print('resume from {}.'.format(resume))
elif resume:
    # Checkpoint path given but missing: warn, then fall through to scratch.
    print('file not found at {}'.format(resume))
else:
    print('Learning from scratch')

device = 'cpu'
# Fix: bind `criterion` unconditionally. Previously it was created only
# inside the CUDA branch, so CPU-only runs crashed with NameError when
# the training loop first used the loss.
criterion = ContrastiveLoss(margin=1.0)
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(cuda_id))
    net = net.to(device)
    criterion = criterion.cuda(device)
    # cudnn autotuner: picks the fastest conv algorithms for fixed input shapes
    cudnn.benchmark = True

print('computation device: {}'.format(device))


def save_checkpoint(state, filename):
    """Serialize a training-state object to disk.

    Parameters
    ----------
    state : object
        Object to serialize, typically a dict holding model/optimizer state.
    filename : str
        Destination path for the checkpoint file.
    """
    # os.path.join(filename) with a single argument was a no-op; write directly.
    torch.save(state, filename)


# Euclidean (L2) distance between paired embedding batches.
pdist = nn.PairwiseDistance(p=2)
for epoch in range(num_epoch):
    net.train()  # enable training-mode layers (dropout, BN stat updates)
    # NOTE(review): _sample_once is a private loader hook — presumably
    # re-draws the pair sampling for this epoch; confirm with its definition.
    train_loader._sample_once()