Example 1
def run_train():

    out_dir = RESULTS_DIR + '/xx10'
    initial_checkpoint = None
    #RESULTS_DIR + '/xx10/checkpoint/00002200_model.pth'

    pretrain_file = None
    skip = []

    ## setup  -----------------
    os.makedirs(out_dir + '/checkpoint', exist_ok=True)
    os.makedirs(out_dir + '/train', exist_ok=True)
    os.makedirs(out_dir + '/backup', exist_ok=True)
    backup_project_as_zip(PROJECT_PATH,
                          out_dir + '/backup/code.train.%s.zip' % IDENTIFIER)

    log = Logger()
    log.open(out_dir + '/log.train.txt', mode='a')
    log.write('\n--- [START %s] %s\n\n' % (IDENTIFIER, '-' * 64))
    log.write('\tSEED         = %u\n' % SEED)
    log.write('\tPROJECT_PATH = %s\n' % PROJECT_PATH)
    log.write('\tout_dir      = %s\n' % out_dir)
    log.write('\n')

    # fig = plt.figure(figsize=(5,5))
    # ax  = fig.add_subplot(111, projection='3d')

    fig1 = plt.figure(figsize=(5, 5))
    ax1 = fig1.add_subplot(111, projection='3d')

    fig2 = plt.figure(figsize=(5, 5))
    ax2 = fig2.add_subplot(111, projection='3d')

    ## net ----------------------
    log.write('** net setting **\n')
    net = Net().cuda()

    if initial_checkpoint is not None:
        log.write('\tinitial_checkpoint = %s\n' % initial_checkpoint)
        net.load_state_dict(
            torch.load(initial_checkpoint,
                       map_location=lambda storage, loc: storage))
        # cfg = load_pickle_file(out_dir +'/checkpoint/configuration.pkl')

    if pretrain_file is not None:
        log.write('\tpretrain_file = %s\n' % pretrain_file)
        net.load_pretrain(pretrain_file, skip)

    log.write('%s\n\n' % (type(net)))
    log.write('\n')

    ## optimiser ----------------------------------
    iter_accum = 1
    batch_size = 2
    valid_size = 2

    num_iters = 100
    iter_smooth = 20
    iter_log = 50
    iter_valid = 100
    iter_save = [0, num_iters - 1] + list(range(0, num_iters, 200))  # save at start, end, and every 200 iterations

    LR = None  #LR = StepLR([ (0, 0.01),  (200, 0.001),  (300, -1)])
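    # note: the base learning rate is divided by iter_accum because gradients from
    # iter_accum mini-batches are summed before each optimizer step; the reported
    # 'rate' below multiplies it back so the log shows the effective learning rate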
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                          lr=0.001 / iter_accum,
                          momentum=0.9,
                          weight_decay=0.0001)

    start_iter = 0
    start_epoch = 0.
    if initial_checkpoint is not None:
        checkpoint = torch.load(
            initial_checkpoint.replace('_model.pth', '_optimizer.pth'))
        start_iter = checkpoint['iter']
        start_epoch = checkpoint['epoch']

        # remember the freshly configured learning rate: loading the optimizer
        # state overwrites it, so restore it afterwards
        rate = get_learning_rate(optimizer)
        optimizer.load_state_dict(checkpoint['optimizer'])
        adjust_learning_rate(optimizer, rate)
        pass

    ## dataset ----------------------------------------
    log.write('** dataset setting **\n')

    train_dataset = DummyDataset('samples_train',
                                 mode='<not_used>',
                                 transform=train_augment)

    train_loader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        #sampler = SequentialSampler(train_dataset),
        batch_size=batch_size,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=train_collate)

    valid_dataset = DummyDataset('samples_valid',
                                 mode='<not_used>',
                                 transform=train_augment)

    valid_loader = DataLoader(
        valid_dataset,
        sampler=RandomSampler(valid_dataset),
        #sampler = SequentialSampler(valid_dataset),
        batch_size=valid_size,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=train_collate)

    # log.write('\ttrain_dataset.split = %s\n'%(train_dataset.split))
    # log.write('\tvalid_dataset.split = %s\n'%(valid_dataset.split))
    log.write('\tlen(train_dataset)  = %d\n' % (len(train_dataset)))
    log.write('\tlen(valid_dataset)  = %d\n' % (len(valid_dataset)))
    # log.write('\tlen(train_loader)   = %d\n'%(len(train_loader)))
    # log.write('\tlen(valid_loader)   = %d\n'%(len(valid_loader)))
    log.write('\tbatch_size  = %d\n' % (batch_size))
    log.write('\titer_accum  = %d\n' % (iter_accum))
    log.write('\tbatch_size*iter_accum  = %d\n' % (batch_size * iter_accum))
    log.write('\n')

    #<debug>========================================================================================
    if 0:
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111, projection='3d')

        for tracklets, truths, datas, labels, lengths, indices in train_loader:

            batch_size = len(indices)
            print('batch_size=%d' % batch_size)

            tracklets = tracklets.data.cpu().numpy()
            truths = truths.data.cpu().numpy()
            split = np.cumsum(lengths)
            tracklets = np.split(tracklets, split)
            truths = np.split(truths, split)

            for b in range(batch_size):
                ax.clear()

                data = datas[b]
                x = data.x.values
                y = data.y.values
                z = data.z.values
                ax.plot(x, y, z, '.', color=[0.75, 0.75, 0.75], markersize=3)

                tracklet = tracklets[b].reshape(-1, 3, 3)
                truth = truths[b]

                pos = np.where(truth == 1)[0]
                for i in pos:
                    t = tracklet[i]
                    color = np.random.uniform(0, 1, (3))
                    ax.plot(t[:, 0],
                            t[:, 1],
                            t[:, 2],
                            '.-',
                            color=color,
                            markersize=6)

                set_figure(ax,
                           x_limit=(0, -100),
                           y_limit=(-20, 20),
                           z_limit=(500, 1000))
                plt.pause(0.01)
        plt.show()

    #<debug>========================================================================================

    ## start training here! ##############################################
    log.write('** start training here! **\n')
    log.write(' optimizer=%s\n' % str(optimizer))
    log.write(' momentum=%f\n' % optimizer.param_groups[0]['momentum'])
    log.write(' LR=%s\n\n' % str(LR))

    log.write(' images_per_epoch = %d\n\n' % len(train_dataset))
    log.write(
        ' rate    iter   epoch  num   | valid_loss        | train_loss       | batch_loss      |  time       \n'
    )
    log.write(
        '----------------------------------------------------------------------------------------------------\n'
    )

    train_loss = np.zeros(6, np.float32)
    valid_loss = np.zeros(6, np.float32)
    batch_loss = np.zeros(6, np.float32)
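    # the loss arrays reserve six slots for extra metrics; only index 0
    # (the binary cross-entropy loss) is populated in this example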
    rate = 0

    start = timer()
    j = 0
    i = 0

    while i < num_iters:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6, np.float32)
        batch_count = 0

        net.set_mode('train')
        optimizer.zero_grad()

        for tracklets, truths, datas, labels, lengths, indices in train_loader:

            batch_size = len(indices)
            i = j // iter_accum + start_iter  # keep the iteration counter an integer
            epoch = (i - start_iter) * batch_size * iter_accum / len(
                train_dataset) + start_epoch
            num_products = epoch * len(train_dataset)

            if i % iter_valid == 0:
                net.set_mode('valid')
                valid_loss = evaluate(net, valid_loader)
                #print(valid_loss)
                net.set_mode('train')

                print('\r', end='', flush=True)
                log.write('%0.4f %5.1f k %6.1f %4.1f m |  %0.3f  |  %0.3f  |  %0.3f  | %s\n' % (
                    rate, i / 1000, epoch, num_products / 1000000,
                    valid_loss[0], train_loss[0], batch_loss[0],
                    time_to_str((timer() - start) / 60)))
                time.sleep(0.01)

            #if 1:
            if i in iter_save:
                torch.save(net.state_dict(),
                           out_dir + '/checkpoint/%08d_model.pth' % (i))
                torch.save(
                    {
                        'optimizer': optimizer.state_dict(),
                        'iter': i,
                        'epoch': epoch,
                    }, out_dir + '/checkpoint/%08d_optimizer.pth' % (i))

            # learning rate scheduler -------------
            if LR is not None:
                lr = LR.get_rate(i)
                if lr < 0: break
                adjust_learning_rate(optimizer, lr / iter_accum)
            rate = get_learning_rate(optimizer) * iter_accum

            # one iteration update  -------------
            tracklets = tracklets.cuda()
            truths = truths.cuda()

            logits = net.forward(tracklets)
            loss = F.binary_cross_entropy_with_logits(logits, truths)

            # accumulated update
            loss.backward()
            if j % iter_accum == 0:
                #torch.nn.utils.clip_grad_norm(net.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            # print statistics  ------------
            batch_loss = np.array((loss.item(), 0, 0, 0, 0, 0))
            sum_train_loss += batch_loss
            batch_count += 1
            if i % iter_smooth == 0:
                train_loss = sum_train_loss / batch_count
                sum_train_loss = np.zeros(6, np.float32)
                batch_count = 0


            print('\r%0.4f %5.1f k %6.1f %4.1f m |  %0.3f  |  %0.3f  |  %0.3f  | %s  %d,%d,%s' % (
                rate, i / 1000, epoch, num_products / 1000000,
                valid_loss[0], train_loss[0], batch_loss[0],
                time_to_str((timer() - start) / 60), i, j, ''),
                end='', flush=True)
            j = j + 1

            #<debug> ===================================================================
            if 1:
                #if i%200==0:
                net.set_mode('test')
                with torch.no_grad():
                    logits = net.forward(tracklets)

                tracklets = tracklets.data.cpu().numpy()
                probs = np_sigmoid(logits.data.cpu().numpy())
                truths = truths.data.cpu().numpy()

                batch_size = len(indices)
                split = np.cumsum(lengths)
                tracklets = np.split(tracklets, split)
                probs = np.split(probs, split)
                truths = np.split(truths, split)

                for b in range(batch_size):
                    ax1.clear()
                    ax2.clear()

                    data = datas[b]
                    x = data.x.values
                    y = data.y.values
                    z = data.z.values
                    ax1.plot(x,
                             y,
                             z,
                             '.',
                             color=[0.75, 0.75, 0.75],
                             markersize=3)
                    ax2.plot(x,
                             y,
                             z,
                             '.',
                             color=[0.75, 0.75, 0.75],
                             markersize=3)

                    tracklet = tracklets[b]
                    prob = probs[b]
                    truth = truths[b]

                    #idx = np.where(prob>0.5)[0]
                    #for i in idx:
                    threshold = 0.5
                    # index with k so the outer iteration counter i is not shadowed
                    for k in range(len(truth)):
                        t = tracklet[k].reshape(-1, 3)

                        if prob[k] > threshold and truth[k] > 0.5:  # hit
                            color = np.random.uniform(0, 1, (3))
                            ax1.plot(t[:, 0],
                                     t[:, 1],
                                     t[:, 2],
                                     '.-',
                                     color=color,
                                     markersize=6)

                        if prob[k] > threshold and truth[k] < 0.5:  # false positive
                            ax2.plot(t[:, 0],
                                     t[:, 1],
                                     t[:, 2],
                                     '.-',
                                     color=[0, 0, 0],
                                     markersize=6)
                        if prob[k] < threshold and truth[k] > 0.5:  # miss
                            ax2.plot(t[:, 0],
                                     t[:, 1],
                                     t[:, 2],
                                     '.-',
                                     color=[1, 0, 0],
                                     markersize=6)

                    set_figure(ax1,
                               title='hit   @sample%d' % indices[b],
                               x_limit=(0, -100),
                               y_limit=(-20, 20),
                               z_limit=(500, 1000))
                    set_figure(ax2,
                               title='error @sample%d' % indices[b],
                               x_limit=(0, -100),
                               y_limit=(-20, 20),
                               z_limit=(500, 1000))
                    plt.pause(0.01)
                    #fig.savefig(out_dir +'/train/%05d.png'%indices[b])
                pass

                net.set_mode('train')
            #<debug> ===================================================================

        pass  #-- end of one data loader --
    pass  #-- end of all iterations --

    if 1:  # save the final model
        torch.save(net.state_dict(),
                   out_dir + '/checkpoint/%08d_model.pth' % (i))
        torch.save(
            {
                'optimizer': optimizer.state_dict(),
                'iter': i,
                'epoch': epoch,
            }, out_dir + '/checkpoint/%08d_optimizer.pth' % (i))

    log.write('\n')
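
A few small helpers used above (get_learning_rate, adjust_learning_rate) are not defined in this example. A minimal sketch of what they might look like, assuming the standard PyTorch param_groups layout, is:

def get_learning_rate(optimizer):
    # read the learning rate from the first parameter group
    return optimizer.param_groups[0]['lr']

def adjust_learning_rate(optimizer, lr):
    # set the same learning rate on every parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
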
Example 2
def main(data_dir,
         results_dir,
         weights_dir,
         which_dataset,
         image_resize,
         image_crop_size,
         exp_num,
         max_epochs,
         batch_size,
         samples_update_size,
         num_workers=4,
         lr=5e-6,
         weight_decay=1e-5):
    """
    This is the main function; it is the only function you need to interface with to train, and it records all of the results.
    Once you have trained, use create_db.py to create the embeddings and then inference_on_single_image.py to test.
    
    Arguments:
        data_dir    : parent directory for data
        results_dir : directory to store the results (Make sure you create this directory first)
        weights_dir : directory to store the weights (Make sure you create this directory first)
        which_dataset : "oxford" or "paris" 
        image_resize : resize to this size
        image_crop_size : square crop size
        exp_num     : experiment number to record the log and results
        max_epochs  : maximum epochs to run
        batch_size  : batch size (I used 5)
        samples_update_size : Number of samples the network should see before it performs one parameter update (I used 64)
    
    Keyword Arguments:
        num_workers : default 4
        lr      : Initial learning rate (default 5e-6)
        weight_decay: default 1e-5

    Example run:
        if __name__ == '__main__':
            main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
            which_dataset="oxbuild", image_resize=460, image_crop_size=448,
            exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)
    """
    # Define directories
    labels_dir = os.path.join(data_dir, which_dataset, "gt_files")
    image_dir = os.path.join(data_dir, which_dataset, "images")

    # Create Query extractor object
    q_train = QueryExtractor(labels_dir, image_dir, subset="train")
    q_valid = QueryExtractor(labels_dir, image_dir, subset="valid")

    # Create transforms
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transforms_train = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.RandomResizedCrop(image_crop_size, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=(0.80, 1.20)),
        transforms.RandomHorizontalFlip(p=0.50),
        transforms.RandomChoice([
            transforms.RandomRotation(15),
            transforms.Grayscale(num_output_channels=3),
        ]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    transforms_valid = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.CenterCrop(image_crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Create dataset
    dataset_train = VggImageRetrievalDataset(labels_dir,
                                             image_dir,
                                             q_train,
                                             transforms=transforms_train)
    dataset_valid = VggImageRetrievalDataset(labels_dir,
                                             image_dir,
                                             q_valid,
                                             transforms=transforms_valid)

    # Create dataloader
    train_loader = DataLoader(dataset_train,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)
    valid_loader = DataLoader(dataset_valid,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=False)

    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2020)
    torch.manual_seed(2020)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create embedding network
    embedding_model = create_embedding_net()
    model = TripletNet(embedding_model)
    model.to(device)

    # Create optimizer and scheduler
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
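    # note: the scheduler is only created here; it is presumably stepped inside
    # train_model (not shown in this example)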

    # Create log file
    log_file = open(os.path.join(results_dir, "log-{}.txt".format(exp_num)),
                    "w+")
    log_file.write("----------Experiment {}----------\n".format(exp_num))
    log_file.write("Dataset = {}, Image sizes = {}, {}\n".format(
        which_dataset, image_resize, image_crop_size))

    # Create batch update value
    update_batch = int(math.ceil(float(samples_update_size) / batch_size))
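    # e.g. with batch_size=5 and samples_update_size=64 (the values suggested in the
    # docstring), update_batch = ceil(64 / 5) = 13 batches per parameter update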
    model_name = "{}-exp-{}.pth".format(which_dataset, exp_num)
    loss_plot_save_path = os.path.join(
        results_dir, "{}-loss-exp-{}.png".format(which_dataset, exp_num))

    # Print stats before starting training
    print("Running VGG Image Retrieval Training script")
    print("Dataset used\t\t:{}".format(which_dataset))
    print("Max epochs\t\t: {}".format(max_epochs))
    print("Gradient update\t\t: every {} batches ({} samples)".format(
        update_batch, samples_update_size))
    print("Initial Learning rate\t: {}".format(lr))
    print("Image resize, crop size\t: {}, {}".format(image_resize,
                                                     image_crop_size))
    print("Available device \t:", device)

    # Train the triplet network
    tr_hist, val_hist = train_model(model,
                                    device,
                                    optimizer,
                                    scheduler,
                                    train_loader,
                                    valid_loader,
                                    epochs=max_epochs,
                                    update_batch=update_batch,
                                    model_name=model_name,
                                    save_dir=weights_dir,
                                    log_file=log_file)

    # Close the file
    log_file.close()

    # Plot and save
    plot_history(tr_hist,
                 val_hist,
                 "Triplet Loss",
                 loss_plot_save_path,
                 labels=["train", "validation"])


# if __name__ == '__main__':
#     main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
#         which_dataset="oxbuild", image_resize=460, image_crop_size=448,
#         exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)
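
train_model itself is not included in this example. The update_batch argument implies a gradient-accumulation loop; a minimal, hypothetical sketch of that pattern (assuming the loader yields anchor/positive/negative triplets and that TripletNet returns their three embeddings) might look like:

import torch.nn as nn

def run_one_epoch(model, device, optimizer, loader, update_batch):
    # accumulate gradients from update_batch mini-batches before each optimizer step
    criterion = nn.TripletMarginLoss(margin=1.0)
    model.train()
    optimizer.zero_grad()
    for batch_idx, (anchor, positive, negative) in enumerate(loader):
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        emb_a, emb_p, emb_n = model(anchor, positive, negative)
        loss = criterion(emb_a, emb_p, emb_n) / update_batch  # scale for accumulation
        loss.backward()
        if (batch_idx + 1) % update_batch == 0:
            optimizer.step()
            optimizer.zero_grad()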