Example #1
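Example #1 is the complete SCAN training entry point: it builds neighbor-augmented datasets and loaders, trains the clustering heads, tracks the head with the lowest SCAN loss on the validation set, and finally evaluates the best head with Hungarian matching. The snippet starts at main(), so as a minimal sketch (not part of the source), here are the imports and argument parser it assumes, following the module layout of the reference SCAN repository (wvangansbeke/Unsupervised-Classification); the --tb_run flag, the SummaryWriter, and the third create_config parameter are fork-specific, so those parts are assumptions:

import argparse
import os

import torch
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter

# Repo-internal helpers; module paths follow the reference SCAN repo and may differ in this fork.
from utils.config import create_config
from utils.common_config import (get_train_transformations, get_val_transformations,
                                 get_train_dataset, get_train_dataloader,
                                 get_val_dataset, get_val_dataloader,
                                 get_optimizer, get_model, get_criterion,
                                 adjust_learning_rate)
from utils.evaluate_utils import get_predictions, scan_evaluate, hungarian_evaluate
from utils.train_utils import scan_train

# Hypothetical parser matching the FLAGS.parse_args() call below.
FLAGS = argparse.ArgumentParser(description='SCAN loss training')
FLAGS.add_argument('--config_env', help='Path to the environment config file')
FLAGS.add_argument('--config_exp', help='Path to the experiment config file')
FLAGS.add_argument('--tb_run', help='Tensorboard run name (fork-specific)')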
def main():
    args = FLAGS.parse_args()
    p = create_config(args.config_env, args.config_exp, args.tb_run)
    print(colored(p, 'red'))

    # CUDNN
    torch.backends.cudnn.benchmark = True

    # Data
    print(colored('Get dataset and dataloaders', 'blue'))
    train_transformations = get_train_transformations(p)
    val_transformations = get_val_transformations(p)
    train_dataset = get_train_dataset(p,
                                      train_transformations,
                                      split='train',
                                      to_neighbors_dataset=True)
    val_dataset = get_val_dataset(p,
                                  val_transformations,
                                  to_neighbors_dataset=True)
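    # to_neighbors_dataset=True wraps each sample together with its mined
    # nearest neighbors; the SCAN loss consumes these (anchor, neighbor) pairs.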
    train_dataloader = get_train_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Train transforms:', train_transformations)
    print('Validation transforms:', val_transformations)
    print('Train samples %d - Val samples %d' %
          (len(train_dataset), len(val_dataset)))

    # Tensorboard writer
    writer = SummaryWriter(log_dir=p['scan_tb_dir'])

    # Model
    print(colored('Get model', 'blue'))
    model = get_model(p, p['pretext_model'])
    print(model)
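    # DataParallel wraps the model, so its weights live under model.module;
    # this is why the best model below is saved via model.module.state_dict().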
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # Optimizer
    print(colored('Get optimizer', 'blue'))
    optimizer = get_optimizer(p, model, p['update_cluster_head_only'])
    print(optimizer)

    # Warning
    if p['update_cluster_head_only']:
        print(colored('WARNING: SCAN will only update the cluster head',
                      'red'))

    # Loss function
    print(colored('Get loss', 'blue'))
    criterion = get_criterion(p)
    criterion.cuda()
    print(criterion)

    # Checkpoint
    if os.path.exists(p['scan_checkpoint']):
        print(
            colored('Restart from checkpoint {}'.format(p['scan_checkpoint']),
                    'blue'))
        checkpoint = torch.load(p['scan_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        best_loss_head = checkpoint['best_loss_head']

    else:
        print(
            colored('No checkpoint file at {}'.format(p['scan_checkpoint']),
                    'blue'))
        start_epoch = 0
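        # Sentinel: initial best loss, beaten by the first real validation loss.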
        best_loss = 1e4
        best_loss_head = None

    # Main loop
    print(colored('Starting main loop', 'blue'))

    for epoch in range(start_epoch, p['epochs']):
        print(colored('Epoch %d/%d' % (epoch + 1, p['epochs']), 'yellow'))
        print(colored('-' * 15, 'yellow'))

        # Adjust lr
        lr = adjust_learning_rate(p, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # Train
        print('Train ...')
        scan_train(train_dataloader, model, criterion, optimizer, epoch,
                   writer, p['update_cluster_head_only'])

        # Evaluate
        print('Make prediction on validation set ...')
        predictions = get_predictions(p, val_dataloader, model)

        print('Evaluate based on SCAN loss ...')
        scan_stats = scan_evaluate(predictions)
        print(scan_stats)
        lowest_loss_head = scan_stats['lowest_loss_head']
        lowest_loss = scan_stats['lowest_loss']

        if lowest_loss < best_loss:
            print('New lowest loss on validation set: %.4f -> %.4f' %
                  (best_loss, lowest_loss))
            print('Lowest loss head is %d' % lowest_loss_head)
            best_loss = lowest_loss
            best_loss_head = lowest_loss_head
            torch.save(
                {
                    'model': model.module.state_dict(),
                    'head': best_loss_head
                }, p['scan_model'])

        else:
            print('No new lowest loss on validation set: %.4f -> %.4f' %
                  (best_loss, lowest_loss))
            print('Lowest loss head is %d' % best_loss_head)

        print('Evaluate with hungarian matching algorithm ...')
        clustering_stats = hungarian_evaluate(lowest_loss_head,
                                              predictions,
                                              compute_confusion_matrix=False,
                                              tf_writer=writer,
                                              epoch=epoch)
        print(clustering_stats)

        # Checkpoint
        print('Checkpoint ...')
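        # The rolling checkpoint stores the DataParallel state_dict
        # ('module.'-prefixed keys), matching model.load_state_dict() on restart;
        # the best model above stores the unwrapped model.module weights instead.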
        torch.save(
            {
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
                'epoch': epoch + 1,
                'best_loss': best_loss,
                'best_loss_head': best_loss_head
            }, p['scan_checkpoint'])

    # Evaluate and save the final model
    print(
        colored('Evaluate best model based on SCAN metric at the end', 'blue'))
    model_checkpoint = torch.load(p['scan_model'], map_location='cpu')
    model.module.load_state_dict(model_checkpoint['model'])
    predictions, features, thumbnails = get_predictions(p,
                                                        val_dataloader,
                                                        model,
                                                        return_features=True,
                                                        return_thumbnails=True)
    writer.add_embedding(features, predictions[0]['targets'], thumbnails,
                         p['epochs'], p['scan_tb_dir'])
    clustering_stats = hungarian_evaluate(
        model_checkpoint['head'],
        predictions,
        class_names=val_dataset.dataset.classes,
        compute_confusion_matrix=True,
        confusion_matrix_file=os.path.join(p['scan_dir'],
                                           'confusion_matrix.png'))
    print(clustering_stats)
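
For reference, a hypothetical invocation of this script; the script name and config paths are placeholders following the reference repo layout, not taken from the source:

python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_cifar10.yml --tb_run runs/scan_cifar10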
Example #2
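Example #2 is a fragment of the same epoch loop from a modified fork: scan_train is called without the tensorboard writer, predictions are still computed, and the rest of the evaluation is deliberately short-circuited with a continue while the author iterates on training.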
        scan_train(train_dataloader, model, criterion, optimizer, epoch, p['update_cluster_head_only'])

        # Evaluate
        # NOTE: skipping the evaluation below (via the `continue` that follows)
        # because we are not evaluating yet.

        print('Make prediction on validation set ...')
        predictions = get_predictions(p, val_dataloader, model)  # inputting the train data to get the clusters
        # Short-circuit: skip the SCAN-loss evaluation below for now.
        continue
        print('Evaluate based on SCAN loss ...')
        scan_stats = scan_evaluate(predictions)
        print(scan_stats)
        lowest_loss_head = scan_stats['lowest_loss_head']
        lowest_loss = scan_stats['lowest_loss']
       
        if lowest_loss < best_loss:
            print('New lowest loss on validation set: %.4f -> %.4f' % (best_loss, lowest_loss))
            print('Lowest loss head is %d' % lowest_loss_head)
            best_loss = lowest_loss
            best_loss_head = lowest_loss_head
            torch.save({'model': model.module.state_dict(), 'head': best_loss_head}, p['scan_model'])

        else:
            print('No new lowest loss on validation set: %.4f -> %.4f' % (best_loss, lowest_loss))
            print('Lowest loss head is %d' % best_loss_head)