예제 #1
0
# Image directories and their matching ground-truth directories (`*_den`,
# presumably density maps) for the training and validation splits.
train_path = './data/formatted_trainval/shanghaitech_part_A_patches_9/train'
train_gt_path = './data/formatted_trainval/shanghaitech_part_A_patches_9/train_den'
val_path = './data/formatted_trainval/shanghaitech_part_A_patches_9/val'
val_gt_path = './data/formatted_trainval/shanghaitech_part_A_patches_9/val_den'

# Training configuration.
num_epochs = 2000
lr = 0.00001
# NOTE(review): torch.optim.Adam (used below) has no momentum argument, so
# this value is not consumed by the optimizer here; kept because code later
# in the file may still read it.
momentum = 0.9

# Build the network, initialise its weights, move it to `device`, and switch
# it to training mode.
net = MCNN()
weights_normal_init(net, dev=0.01)
net.to(device)
net.train()

# Adam with L2 weight decay, and a mean-squared-error loss.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
criterion = nn.MSELoss()

# Create the checkpoint directory, including any missing parent directories
# (os.mkdir would fail if a parent did not exist yet).
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

print('Loading training and validation datasets')
train_dataset = Shanghai_Dataset(train_path, train_gt_path, gt_downsample=True, pre_load=True)
train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
val_dataset = Shanghai_Dataset(val_path, val_gt_path, gt_downsample=True, pre_load=True)
# The validation loader must wrap val_dataset; wrapping train_dataset would
# make every "validation" metric a measurement on training data.
val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

# Sentinel for the best validation MAE seen so far (presumably lowered by the
# training loop that follows). sys.maxint was removed in Python 3;
# sys.maxsize is the portable large-int equivalent.
best_mae = sys.maxsize
예제 #2
0
                                                   pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=batch_size,
                                                   shuffle=True)

    # Initialize model, loss function, and optimizer
    model = MCNN()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Iterate through training set and return input image and ground truth density map
    for batch_idx, (img, gt_dmap) in enumerate(train_loader):

        # Make sure we are in training mode
        model.train()

        # Torch accumulates gradients so we must zero these out for each batch
        optimizer.zero_grad()

        # Place data and labels in variable to track gradients and place on the GPU if available
        img = Variable(img.float())
        gt_dmap = Variable(gt_dmap.float())
        if torch.cuda.is_available():
            img.cuda()
            pred_dmap.cuda()

        pred_map = model(img)
        import pdb
        pdb.set_trace()