def validate(net, val_data, ctx, use_threads=True):
    """Embed the validation set with ``net`` and score the embeddings.

    Every batch from ``val_data`` is split across the devices in ``ctx``
    (``batch[0]`` is fed to the network, ``batch[1]`` holds the labels).
    Network outputs and labels are copied to the CPU, concatenated along
    axis 0 and handed to ``evaluate`` together with the dataset's class
    count.

    Returns:
        Whatever ``evaluate`` returns.
    """
    cpu = mx.cpu()
    embeddings, targets = [], []

    def shard(array):
        # One slice per device along the batch axis; slices may be uneven.
        return mx.gluon.utils.split_and_load(array,
                                             ctx_list=ctx,
                                             batch_axis=0,
                                             even_split=False)

    for batch in tqdm(val_data, desc='Computing test embeddings'):
        inputs = shard(batch[0])
        batch_labels = shard(batch[1])

        # Forward each device slice and gather everything on the CPU.
        embeddings.extend(net(part).as_in_context(cpu) for part in inputs)
        targets.extend(part.as_in_context(cpu) for part in batch_labels)

    return evaluate(mx.nd.concatenate(embeddings, axis=0),
                    mx.nd.concatenate(targets, axis=0),
                    val_data._dataset.num_classes(),
                    use_threads=use_threads)
Beispiel #2
0
def validate(dataloader, models, context, static_proxies, similarity='cosine'):
    """Compute ensemble embeddings for ``dataloader`` and evaluate them.

    For every device slice of a batch, each model in ``models`` is moved to
    ``context``, run on the slice, and moved back to the CPU before the next
    model is used. The per-model outputs are concatenated along dim 1 to form
    the ensemble embedding.

    Args:
        dataloader: yields batches where ``batch[0]`` is the input,
            ``batch[1]`` the labels and ``batch[2]`` the negative labels.
        models: iterable of networks forming the ensemble.
        context: device list used for the forward passes.
        static_proxies: when truthy the models are called with the input
            only; otherwise they also receive zero tensors shaped like the
            label and negative-label slices, and element ``[0]`` of the
            result is used.
        similarity: forwarded to ``evaluate``.

    Returns:
        Whatever ``evaluate`` returns.
    """
    cpu = mx.cpu()
    embeddings, targets = [], []

    def shard(array):
        # One slice per device along the batch axis; slices may be uneven.
        return mx.gluon.utils.split_and_load(array, ctx_list=context, batch_axis=0, even_split=False)

    for batch in tqdm(dataloader, desc='Computing test embeddings'):
        inputs = shard(batch[0])
        batch_labels = shard(batch[1])
        batch_negatives = shard(batch[2])

        for x, pos, neg in zip(inputs, batch_labels, batch_negatives):
            member_outputs = []
            for model in models:
                model.collect_params().reset_ctx(context)  # move model to GPU
                if static_proxies:
                    out = model(x)
                else:
                    # Labels are replaced by zeros of the same shape on the input's device.
                    dummy_pos = mx.nd.zeros_like(pos, ctx=x.context)
                    dummy_neg = mx.nd.zeros_like(neg, ctx=x.context)
                    out = model(x, dummy_pos, dummy_neg)[0]
                member_outputs.append(out.as_in_context(cpu))
                # Wait for pending work before the parameters leave the device.
                mx.nd.waitall()
                model.collect_params().reset_ctx(cpu)  # move model to CPU
            embeddings.append(mx.nd.concat(*member_outputs, dim=1))

        targets.extend(part.as_in_context(cpu) for part in batch_labels)

    all_embeddings = mx.nd.concatenate(embeddings, axis=0)
    all_targets = mx.nd.concatenate(targets, axis=0)
    logging.info('Evaluating with %s distance' % similarity)
    return evaluate(all_embeddings, all_targets, dataloader._dataset.num_classes(), similarity=similarity)
def run_full_pipeline(metadata: Metadata, model_wrapper: ModelWrapper,
                      model_type: ModelType):
    """Train a model (resuming if a checkpoint exists) and score it on the test set.

    If a model file already exists in ``metadata.model_dir`` the stored
    metadata is restored first: a finished run returns immediately, an
    unfinished one reloads the model and continues. Otherwise the model
    directory is created. After training, the test-set loss and the test
    transformation info are written into ``metadata`` and saved as JSON.
    """
    print_cuda_info()

    checkpoint_path = os.path.join(metadata.model_dir, MODEL_FILENAME)
    resuming = os.path.exists(checkpoint_path)

    if resuming:
        Metadata.restore_from_json(metadata,
                                   f"{model_wrapper.model_dir}/metadata.json")
        if metadata.training_finished:
            print(
                f"\n\n\nModel at {metadata.model_dir} already finished training."
            )
            return
        print(
            f"\n\n\nModel at {metadata.model_dir} already exists, restoring this model."
        )
        model_wrapper.load()
    else:
        os.makedirs(metadata.model_dir, exist_ok=True)

    metadata.num_params = model_wrapper.num_parameters()

    transformations = TransformationsManager.get_transformations(
        metadata.transformations)
    train_loader, val_loader, test_loader = create_data_loaders(
        Dataset(metadata),
        metadata,
        model=model_type,
        transformations=transformations,
    )

    if resuming:
        # Carry over the transformation counters recorded by the interrupted run.
        count_key = "transformations_count"
        train_loader.tm.transformations_count = metadata.train_transformations[count_key]
        val_loader.tm.transformations_count = metadata.val_transformations[count_key]

    is_gan = model_type == ModelType.SEGAN
    train_model(
        metadata=metadata,
        wrapper=model_wrapper,
        train_loader=train_loader,
        val_loader=val_loader,
        gan=is_gan,
    )

    test_mse_loss = evaluate(model_wrapper, test_loader)
    print(f"Test set mse loss: {test_mse_loss}")

    # Record the final results and mark the run as complete.
    metadata.test_mse_loss = test_mse_loss
    metadata.training_finished = True
    metadata.test_transformations = test_loader.tm.get_info()

    metadata.save_to_json(TRAINING_RESULTS_FILENAME)
def train_model(
    metadata: Metadata,
    wrapper: ModelWrapper,
    train_loader: DataLoader,
    val_loader: DataLoader,
    gan: bool = False,
):
    """Run the training loop with per-epoch validation and early stopping.

    Epochs run from ``metadata.current_epoch`` up to ``metadata.epochs``.
    After each epoch the validation loss is computed, the summary and
    metadata are updated and saved, and the model is saved whenever the
    validation loss improved. Training stops early once the number of
    consecutive non-improving epochs equals ``metadata.patience``. The
    wrapper's ``load()`` is called once the loop ends.
    """
    epochs_without_improvement = 0

    print("\nStarting training.")

    summary = TrainingSummary.get(train_loader.batch_size, metadata, gan)
    epoch_range = range(metadata.current_epoch, metadata.epochs)

    for epochs_done, epoch in enumerate(epoch_range):
        # Transformations are switched on only after the warm-up epochs,
        # counted from the first epoch run by this call.
        transformations_on = epochs_done >= metadata.warmup_epochs
        train_loader.tm_active = transformations_on
        val_loader.tm_active = transformations_on

        # One full pass over the training set.
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}."):
            for mini_batch in batch:
                step_loss = wrapper.train_step(mini_batch)
                training_size = mini_batch[0].shape[0]
                summary.add_step(step_loss,
                                 current_batch_size=training_size)

        # Validation pass.
        val_loss = evaluate(wrapper, val_loader)

        summary.add_training_epoch(epoch)
        summary.add_epoch_val_loss(val_loss)
        summary.update_metadata(metadata)
        metadata.train_transformations = train_loader.tm.get_info()
        metadata.val_transformations = val_loader.tm.get_info()

        # Persist progress so an interrupted run can resume from the next epoch.
        metadata.current_epoch = epoch + 1
        metadata.save_to_json(TRAINING_RESULTS_FILENAME)

        if summary.val_loss_improved():
            wrapper.save()
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement == metadata.patience:
            print(
                f"\nStopping training since validation loss hasn't improved for {metadata.patience} epochs."
            )
            break

    wrapper.load()
    print("\nTraining finished.")
def run_overfit(metadata: Metadata, model_wrapper: ModelWrapper,
                model_type: ModelType):
    """Train and evaluate on the same (training) data as an overfitting check.

    The training loader is built without transformations and a deep copy of
    it is used both as the validation loader during training and for the
    final evaluation. If a model file already exists in
    ``metadata.model_dir`` the stored metadata is restored: a finished run
    returns immediately, an unfinished one reloads the model and continues.
    The final loss is stored in ``metadata.final_mse_loss`` and the metadata
    is saved as JSON.
    """
    print_cuda_info()

    # Created unconditionally, so no second makedirs is needed further down.
    os.makedirs(metadata.model_dir, exist_ok=True)

    if os.path.exists(os.path.join(metadata.model_dir, MODEL_FILENAME)):
        # NOTE(review): the existence check uses metadata.model_dir while the
        # restore path uses model_wrapper.model_dir -- confirm both always
        # point at the same directory.
        Metadata.restore_from_json(metadata,
                                   f"{model_wrapper.model_dir}/metadata.json")
        if metadata.training_finished:
            print(
                f"\n\n\nModel at {metadata.model_dir} already finished training."
            )
            return
        print(
            f"\n\n\nModel at {metadata.model_dir} already exists, restoring this model."
        )
        model_wrapper.load()

    metadata.num_params = model_wrapper.num_parameters()

    dataset = Dataset(metadata)

    # Only the training split is used; validation/test loaders are discarded.
    train_loader, _, _ = create_data_loaders(
        dataset,
        metadata,
        model=model_type,
        transformations=TransformationsManager.get_transformations("none"),
    )

    eval_loader = deepcopy(train_loader)
    # if training gan, disable additional noisy inputs from datasets
    if model_type == ModelType.SEGAN:
        eval_loader.train_gan = False

    train_model(
        metadata=metadata,
        wrapper=model_wrapper,
        train_loader=train_loader,
        val_loader=eval_loader,
        gan=(model_type == ModelType.SEGAN),
    )

    test_mse_loss = evaluate(model_wrapper, eval_loader)
    print(f"Final mse loss: {test_mse_loss}")

    metadata.final_mse_loss = test_mse_loss
    metadata.training_finished = True
    metadata.save_to_json(TRAINING_RESULTS_FILENAME)
Beispiel #6
0
def validate(net, val_data, ctx, binarize=True, nmi=True, similarity='euclidean'):
    """Embed the validation set with ``net`` and score the embeddings.

    The first element of the network's output is taken as the embedding.
    When ``binarize`` is truthy the embedding is replaced by the result of
    ``embedding > 0``. Embeddings and labels are gathered on the CPU,
    concatenated along axis 0 and passed to ``evaluate`` with detailed
    metrics disabled.

    Returns:
        Whatever ``evaluate`` returns.
    """
    cpu = mx.cpu()
    embeddings, targets = [], []

    def shard(array):
        # One slice per device along the batch axis; slices may be uneven.
        return mx.gluon.utils.split_and_load(array, ctx_list=ctx, batch_axis=0, even_split=False)

    for batch in tqdm(val_data, desc='Computing test embeddings'):
        inputs = shard(batch[0])
        batch_labels = shard(batch[1])

        for part in inputs:
            emb = net(part)[0]
            emb = emb > 0 if binarize else emb
            embeddings.append(emb.as_in_context(cpu))
        targets.extend(part.as_in_context(cpu) for part in batch_labels)

    all_embeddings = mx.nd.concatenate(embeddings, axis=0)
    all_targets = mx.nd.concatenate(targets, axis=0)
    return evaluate(all_embeddings, all_targets, val_data._dataset.num_classes(),
                    similarity=similarity, get_detailed_metrics=False, nmi=nmi)
def train_and_evaluate(model,
                       optimizer,
                       train_loader,
                       val_loader,
                       loss_fn,
                       metrics,
                       params,
                       run_dir,
                       device,
                       scheduler=None,
                       restore_file=None,
                       writer=None):
    """
    Train the model and evaluate on every epoch

    Args:
        model: (inherits torch.nn.Module) the custom neural network model
        optimizer: (inherits torch.optim) optimizer to update the model parameters
        train_loader: (DataLoader) a torch.utils.data.DataLoader object that fetches
                      the training set
        val_loader: (DataLoader) a torch.utils.data.DataLoader object that fetches
                    the validation set
        loss_fn : (function) a function that takes batch_output (tensor) and batch_labels
                  (np.ndarray) and return the loss (tensor) over the batch
        metrics: (dict) a dictionary of functions that compute a metric using the
                 batch_output and batch_labels
        params: (Params) hyperparameters; ``params.num_epochs`` sets the number
                of epochs
        run_dir: (string) directory containing params.json, learned weights, and logs
        device: (str) device type; usually 'cuda:0' or 'cpu'
        scheduler: optional learning-rate scheduler; its ``step()`` is called
                   once per epoch after validation
        restore_file: (string) optional = name of file to restore training from,
                      without extension; '.pth.zip' is appended and the file is
                      looked up in run_dir. A missing file is silently ignored.
        writer: (tensorboard) tensorboard summary writer

    """

    # reload the weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(run_dir, restore_file + '.pth.zip')
        if os.path.exists(restore_path):
            logging.info("Restoring weights from {}".format(restore_path))
            load_checkpoint(restore_path, model, optimizer)

    best_val_accu = 0.0

    for epoch in range(params.num_epochs):

        # running one epoch
        logging.info("Epoch {} / {}".format(epoch + 1, params.num_epochs))

        # logging current learning rate
        for i, param_group in enumerate(optimizer.param_groups):
            logging.info("learning rate = {} for parameter group {}".format(
                param_group['lr'], i))

        # train for one full pass over the training set
        train_metrics, batch_summ = train(model, optimizer, loss_fn,
                                          train_loader, metrics, params,
                                          epoch, device, writer)

        # evaluate for one epoch on the validation set
        val_metrics = evaluate(model, loss_fn, val_loader, metrics, params,
                               device)

        # schedule learning rate
        if scheduler is not None:
            scheduler.step()

        # check if current epoch has best accuracy (ties count as best)
        val_accu = val_metrics['accuracy']
        is_best = val_accu >= best_val_accu

        # save weights
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=run_dir)

        # save batch summaries
        save_batch_summary(run_dir, batch_summ)

        # if best accuracy
        if is_best:
            logging.info(
                "- Found new best accuray model at epoch {}".format(epoch + 1))
            best_val_accu = val_accu

        # add training log to tensorboard
        if writer is None:
            continue

        # train and validation per-epoch mean metrics (only metrics present
        # in both dictionaries are plotted)
        for metric, value in train_metrics.items():
            if metric in val_metrics:
                writer.add_scalars(metric, {
                    'train': value,
                    'val': val_metrics[metric]
                }, epoch)

        # layer weights / gradients distributions
        for idx, m in enumerate(model.modules()):
            if not isinstance(m, (nn.Conv2d, nn.Linear)):
                continue
            # The gradient is only inspected when the weight itself exists;
            # reading ``.grad`` on a missing weight would raise AttributeError.
            if m.weight is None:
                continue
            writer.add_histogram('layer{}.weight'.format(idx), m.weight,
                                 epoch)
            if m.weight.grad is not None:
                writer.add_histogram('layer{}.weight.grad'.format(idx),
                                     m.weight.grad, epoch)
Beispiel #8
0
# Driver script: train, run and evaluate the U-Net variant "with learning",
# then report the wall-clock time and the evaluation result.
start_time = time.perf_counter()

# Remove reconstruction output left behind by earlier runs.
for dirpath in (pathlib.Path(f'{DEVICE}rec_without'),
                pathlib.Path(f'{DEVICE}rec_with')):
    if os.path.isdir(dirpath):
        shutil.rmtree(dirpath)

acc = [8]
cf = [0.08]

# Baseline run without learning (currently disabled):
#train_unet(False,acc,cf,DEVICE)
#run_unet(False,acc,cf,DEVICE)
#no_learning=evaluate(False,DEVICE)

train_unet(True, acc, cf, DEVICE)
run_unet(True, acc, cf, DEVICE)
with_learning = evaluate(True, DEVICE)

# Elapsed time in seconds, then converted to hours for the report.
tt = time.perf_counter() - start_time
tth = tt / 3600

print('****************************************************')
print(f'Device: {DEVICE}')
print(f'Total time: {tth}')
print('****************************************************')
print(f'acceleration {acc[0]},center of {cf[0]}, Regular:')
#print(no_learning)
print('With learning:')
print(with_learning)
print('****************************************************')