示例#1
0
    acc /= times
    return acc, loss


def inference(model, X):  # Test Process
    """Run ``model`` on the NumPy array ``X`` with gradient tracking disabled.

    The model is switched to evaluation mode (and left in it). ``X`` is
    wrapped with ``torch.from_numpy`` and moved to the module-level
    ``device``; the prediction is returned as a NumPy array on the CPU.
    """
    model.eval()
    batch = torch.from_numpy(X).to(device)
    with torch.no_grad():
        prediction = model(batch)
    return prediction.detach().cpu().numpy()


if __name__ == '__main__':
    # Script entry point. `device` is bound at module level here and is read as
    # a global by inference().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Create the training directory if it is missing (os.mkdir creates a single
    # level only). Presumably it holds checkpoints — see the commented-out
    # model_path below.
    if not os.path.exists(args.train_dir):
        os.mkdir(args.train_dir)
    # One TensorBoard run directory per (drop_rate, batch_norm) configuration.
    writer = SummaryWriter(
        f"../tensorboard/mlp_{args.drop_rate}_{args.batch_norm}")
    # One statistics tracker per split, all given the same writer.
    train_stat = Stat(status='train', writer=writer)
    val_stat = Stat(status='val', writer=writer)
    test_stat = Stat(status='test', writer=writer)
    if args.is_train:
        X_train, X_test, y_train, y_test = load_cifar_2d(args.data_dir)
        # Hold out everything from index 40000 onward of the training split as
        # the validation set; the first 40000 samples stay for training.
        X_val, y_val = X_train[40000:], y_train[40000:]
        X_train, y_train = X_train[:40000], y_train[:40000]
        mlp_model = Model(drop_rate=args.drop_rate, batch_norm=args.batch_norm)
        mlp_model.to(device)
        print(device)
        print(mlp_model)
        optimizer = optim.Adam(mlp_model.parameters(), lr=args.learning_rate)

        # NOTE(review): the remainder of this script (checkpoint loading, the
        # training loop, the non-training branch) is not visible here.
        # model_path = os.path.join(args.train_dir, 'checkpoint_%d.pth.tar' % args.inference_version)
        # if os.path.exists(model_path):
    def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform, input_val_transform,
                 output_train_transform, output_val_transform, losses, scheduler=None):
        """Validate and store every component needed to train ``model``.

        Args:
            args: Run configuration. Attributes read here: ``log_path`` (joined with ``run_name`` via ``/``,
                so presumably a ``pathlib.Path``), ``run_name``, ``max_images``, ``batch_size``,
                ``save_best_only``, ``ckpt_path``, ``max_to_keep``, ``num_workers``, ``verbose``,
                ``num_epochs``, ``use_slice_metrics``, ``device`` and, optionally, ``prev_model_ckpt``.
            model: Network to train (``nn.Module``).
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Training ``DataLoader``.
            val_loader: Validation ``DataLoader``.
            input_train_transform: Callable input transform used for training.
            input_val_transform: Callable input transform used for validation.
            output_train_transform: ``nn.Module`` output transform used for training.
            output_val_transform: ``nn.Module`` output transform used for validation.
            losses: Mapping of names to loss modules; wrapped in ``nn.ModuleDict``. The summary log reads
                ``losses['rss_loss']``, so that key must be present.
            scheduler: Optional learning-rate scheduler. ``ReduceLROnPlateau`` sets ``self.metric_scheduler``
                to True; any other scheduler, or ``None``, leaves it False.

        Raises:
            AssertionError: If a component has the wrong type.
            TypeError: If ``scheduler`` is not a PyTorch learning-rate scheduler.
        """

        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \
            '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a tuple as its output.
        losses = nn.ModuleDict(losses)

        # PyTorch >= 2.0 exposes the public base class ``LRScheduler`` and the built-in schedulers derive from
        # it directly, so ``isinstance(s, _LRScheduler)`` is False for them. Fall back to the private name on
        # older releases that do not have ``LRScheduler``.
        lr_scheduler_base = getattr(optim.lr_scheduler, 'LRScheduler', None) or optim.lr_scheduler._LRScheduler

        # Defined even without a scheduler so later code can read it unconditionally.
        self.metric_scheduler = False
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif not isinstance(scheduler, lr_scheduler_base):
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=True)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader

        # A shuffled loader over train + val combined.
        concat_dataset = ConcatDataset([train_loader.dataset, val_loader.dataset])
        self.concat_loader = DataLoader(dataset=concat_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.num_workers, collate_fn=temp_collate_fn, pin_memory=False)

        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_train_transform = output_train_transform
        self.output_val_transform = output_val_transform
        self.losses = losses
        self.scheduler = scheduler
        self.writer = SummaryWriter(str(args.log_path))

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.use_slice_metrics = args.use_slice_metrics

        # This part should get SSIM, not 1 - SSIM.
        self.ssim = SSIM(filter_size=7).to(device=args.device)  # Needed to cache the kernel.

        # Logging all components of the Model Trainer.
        # Train and Val input and output transforms are assumed to use the same input transform class.
        self.logger.info(f'''
        Summary of Model Trainer Components:
        Model: {get_class_name(model)}.
        Optimizer: {get_class_name(optimizer)}.
        Input Transforms: {get_class_name(input_val_transform)}.
        Output Transform: {get_class_name(output_val_transform)}.
        RSS Image Domain Loss: {get_class_name(losses['rss_loss'])}.
        Learning-Rate Scheduler: {get_class_name(scheduler)}.
        ''')  # This part has parts different for IMG and CMG losses!!
示例#3
0
def logger_context(
    log_dir,
    run_ID,
    name,
    log_params=None,
    snapshot_mode="none",
    override_prefix=False,
    use_summary_writer=False,
):
    """Use as context manager around calls to the runner's ``train()`` method.

    Sets up the logger directory and filenames. ``/run_{run_ID}`` is appended
    to ``log_dir`` to separate multiple runs of the same settings. If the
    resulting absolute path is not already inside ``LOG_DIR`` and
    ``override_prefix`` is False, the directory is relocated with
    ``get_log_dir`` (the printed message describes this as prepending
    ``<LOG_DIR>/local/<yyyymmdd>/``). Hyperparameters provided in
    ``log_params`` are saved to ``params.json`` together with the experiment
    ``name`` and ``run_ID``. The caller's ``log_params`` dict is copied, not
    modified.

    The runner asks the logger to save a snapshot at every iteration, and
    ``snapshot_mode`` decides how often the logger actually writes it (a
    snapshot may include agent parameters). Possible modes include (but check
    inside the logger itself):
        * "none": don't save at all
        * "last": always save and overwrite the previous
        * "all": always save and keep each iteration
        * "gap": save periodically and keep each (the gap must also be set;
          not done here)

    On exit from the ``with`` block, the text/tabular outputs registered here
    are removed and the prefix is popped. This also happens when the block
    exits via an exception, so another training session in the same Python
    process starts from a clean logger.
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    log_dir = osp.join(log_dir, f"run_{run_ID}")
    exp_dir = osp.abspath(log_dir)
    if LOG_DIR != osp.commonpath([exp_dir, LOG_DIR]) and not override_prefix:
        print(f"logger_context received log_dir outside of {LOG_DIR}: "
              f"prepending by {LOG_DIR}/local/<yyyymmdd>/")
        exp_dir = get_log_dir(log_dir)
    tabular_log_file = osp.join(exp_dir, "progress.csv")
    text_log_file = osp.join(exp_dir, "debug.log")
    params_log_file = osp.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    if use_summary_writer:
        logger.set_tf_summary_writer(SummaryWriter(exp_dir))
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix(f"{name}_{run_ID} ")

    try:
        # Copy so the caller's dict does not gain "name"/"run_ID" keys.
        params = dict(log_params) if log_params is not None else {}
        params["name"] = name
        params["run_ID"] = run_ID
        with open(params_log_file, "w") as f:
            json.dump(params, f)

        yield
    finally:
        # Undo the registrations above even if training raised.
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
示例#4
0
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 input_train_transform,
                 input_val_transform,
                 output_transform,
                 losses,
                 scheduler=None):
        """Validate and store every component needed to train ``model``.

        Args:
            args: Run configuration. Attributes read here: ``log_path`` (joined
                with ``run_name`` via ``/``, so presumably a ``pathlib.Path``),
                ``run_name``, ``max_images``, ``batch_size``, ``save_best_only``,
                ``ckpt_path``, ``max_to_keep``, ``verbose``, ``num_epochs``,
                ``smoothing_factor``, ``use_slice_metrics`` and, optionally,
                ``prev_model_ckpt``.
            model: Network to train (``nn.Module``).
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Training ``DataLoader``.
            val_loader: Validation ``DataLoader``.
            input_train_transform: Callable input transform used for training.
            input_val_transform: Callable input transform used for validation.
            output_transform: ``nn.Module`` output transform.
            losses: Mapping of names to loss modules; wrapped in ``nn.ModuleDict``.
            scheduler: Optional learning-rate scheduler. ``ReduceLROnPlateau``
                sets ``self.metric_scheduler`` to True; any other scheduler, or
                ``None``, leaves it False.

        Raises:
            AssertionError: If a component has the wrong type.
            TypeError: If ``scheduler`` is not a PyTorch learning-rate scheduler.
        """

        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model,
                          nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(
            optimizer,
            optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(
            output_transform,
            nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with multiple outputs.
        losses = nn.ModuleDict(losses)

        # PyTorch >= 2.0 exposes the public base class ``LRScheduler`` and the
        # built-in schedulers derive from it directly, so
        # ``isinstance(s, _LRScheduler)`` is False for them. Fall back to the
        # private name on older releases that lack ``LRScheduler``.
        lr_scheduler_base = (getattr(optim.lr_scheduler, 'LRScheduler', None)
                             or optim.lr_scheduler._LRScheduler)

        # Defined even without a scheduler so later code can read it unconditionally.
        self.metric_scheduler = False
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif not isinstance(scheduler, lr_scheduler_base):
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model,
                                         optimizer,
                                         mode='min',
                                         save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path,
                                         max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified (model weights only; optimizer state is not restored).
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt,
                              load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.writer = SummaryWriter(str(args.log_path))
示例#5
0
    def __init__(
        self,
        dataset_config='yolov3/config/black.data',
        model_config='yolov3/config/yolov3.cfg',
        pretrained_weights=None,
        optimizer='Adam',
        learning_rate=0.001,
        multiscale_training=True,
        batch_size=32,
        n_cpu=8,
        gradient_accumulation=10,
        print_freq=1,
        validation_freq=5,
        n_epoch=100,
        objectness_threshold=0.8,
        nms_threshold=0.4,
        iou_threshold=0.2,
        image_size=416,
        random_seed=None,
        output_dir=configs.save_dir,
    ):
        """Store the YOLOv3 training configuration and open a TensorBoard writer.

        Each argument is kept as an attribute of the same name, except
        ``output_dir``. It is stored as ``checkpoint_dir`` and is also the
        TensorBoard log directory. ``self.device`` is CUDA when available,
        otherwise CPU.

        Args:
            dataset_config (str): path to data config file
            model_config (str): path to model definition file
            pretrained_weights (str): path to a file containing pretrained weights for the model
            optimizer (str): name of a torch.optim class (Adam, SGD, ...)
            learning_rate (float): learning rate fed to the optimizer
            multiscale_training (bool): whether to sample batches with different image sizes
            batch_size (int): size of a training batch
            n_cpu (int): number of dataloader workers
            gradient_accumulation (int): number of batches whose gradients are accumulated before a descent step
            print_freq (int): inside an epoch, print a status update every print_freq episodes
            validation_freq (int): inside an epoch, how often the model is evaluated on the validation set
            n_epoch (int): number of meta-training epochs
            objectness_threshold (float): at evaluation time, only keep boxes with objectness above this threshold
            nms_threshold (float): threshold for non maximum suppression, at evaluation time
            iou_threshold (float): threshold for intersection over union
            image_size (int): size of (square) images
            random_seed (int): seed for random instantiations. It is stored as given here; picking a random seed
                when it is None is not done in this constructor (presumably elsewhere).
            output_dir (str): path to experiments output directory
        """

        # Data and model definition.
        self.dataset_config = dataset_config
        self.model_config = model_config
        self.pretrained_weights = pretrained_weights
        self.image_size = image_size
        self.multiscale_training = multiscale_training

        # Optimisation.
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.gradient_accumulation = gradient_accumulation
        self.n_epoch = n_epoch
        self.n_cpu = n_cpu

        # Evaluation thresholds.
        self.objectness_threshold = objectness_threshold
        self.nms_threshold = nms_threshold
        self.iou_threshold = iou_threshold

        # Reporting, reproducibility and output.
        self.print_freq = print_freq
        self.validation_freq = validation_freq
        self.random_seed = random_seed
        self.checkpoint_dir = output_dir

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        self.writer = SummaryWriter(log_dir=output_dir)
示例#6
0
def train(config: Dict) -> None:
    """Train a CURL-SAC agent on a pixel-based dmc2gym task.

    Builds the frame-stacked environment, the agent and the replay buffer from
    ``config``, then runs ``config['max_timestep']`` environment steps:

    * At every episode boundary, the episode reward and the latest losses are
      written to TensorBoard. The agent is saved as ``'best'`` whenever a
      finished episode beats the best reward so far.
    * Every ``config['save_interval']`` steps, a checkpoint named after the
      step is saved.
    * Agent updates start once ``step > config['train_start_step']``.

    Args:
        config: Experiment configuration. Keys read here: ``domain_name``,
            ``task_name``, ``seed``, ``pre_transform_image_size``,
            ``frame_stack``, ``img_size``, ``buffer_size``, ``batch_size``,
            ``max_timestep``, ``save_interval``, ``train_start_step`` and
            ``exp_path`` (presumably filled in by ``make_exp_path``).
    """
    make_exp_path(config)

    env = dmc2gym.make(domain_name=config['domain_name'],
                       task_name=config['task_name'],
                       seed=config['seed'],
                       visualize_reward=False,
                       from_pixels=True,
                       height=config['pre_transform_image_size'],
                       width=config['pre_transform_image_size'])
    env = FrameStack(env, config['frame_stack'])

    env.seed(config['seed'])
    seed_all(config['seed'])

    env_params = get_env_params(env)
    env_params.update({
        'obs_shape':
        deepcopy((config['frame_stack'] * 3, config['img_size'],
                  config['img_size']))
    })

    save_config_and_env_params(config, env_params)

    agent = CURL_SACAgent(config, env_params)
    buffer = ReplayBuffer(env_params['preaug_obs_shape'], env_params['a_dim'],
                          config['buffer_size'], config['batch_size'],
                          config['img_size'])

    # Episode time limit, if the wrapped env exposes one. It is used to tell a
    # time-limit truncation apart from a real terminal state. When absent, the
    # env's own `done` flag is stored unchanged.
    max_episode_steps = getattr(env, '_max_episode_steps', None)

    done = True
    episode_num = 0
    episode_step = 0
    episode_reward = 0
    # -inf (not 0) so that tasks whose episode rewards are negative still get a 'best' checkpoint.
    best_score = -float('inf')
    log_loss = {
        'loss_critic': 0,
        'loss_actor': 0,
        'loss_alpha': 0,
        'loss_cpc': 0,
    }

    logger = SummaryWriter(config['exp_path'])
    try:
        for step in range(config['max_timestep']):
            if done:
                obs = env.reset()

                print(
                    f"Episode: {episode_num} Total Steps: {step} Episode Steps: {episode_step} Episode Reward: {episode_reward}"
                )
                logger.add_scalar('Indicator/Episode Reward', episode_reward, step)
                for loss_key, loss_value in log_loss.items():
                    logger.add_scalar(f'Loss/{loss_key}', loss_value, step)

                # At step 0 no episode has been played yet, so there is nothing to compare.
                if step > 0 and episode_reward > best_score:
                    agent.save(config['exp_path'], 'best')
                    best_score = episode_reward

                episode_num += 1
                episode_step = 0
                episode_reward = 0

            if step % config['save_interval'] == 0:
                agent.save(config['exp_path'], str(step))

            action = agent.selection_action(obs, output_mu=False)

            if step > config['train_start_step']:
                log_loss = agent.update(buffer)

            obs_, r, done, _ = env.step(action)
            # Allow infinite bootstrap: an episode cut off by the time limit is
            # stored as non-terminal, so the critic keeps bootstrapping from obs_.
            truncated = (max_episode_steps is not None
                         and episode_step + 1 == max_episode_steps)
            done_bool = False if truncated else done

            buffer.add(obs, action, r, done_bool, obs_)

            obs = obs_
            episode_step += 1
            episode_reward += r
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        logger.close()
示例#7
0
def run(runParams: dict) -> float:
    """Train a HashRoutedNetwork over a sequence of datasets and evaluate lifelong accuracy.

    Parameters are read from the YAML file ``runParams["paramsFile"]`` when it
    is set, with ``device`` and ``subDeviceIdx`` taken from ``runParams``.
    Otherwise ``runParams`` itself is used.

    For each dataset in ``params["training"]["datasets"]``:

    * A new ``Decoder`` is trained jointly with the shared ``hrnet``.
    * Every earlier dataset's test set is then re-encoded and scored with that
      dataset's stored decoder, and the running global accuracy is logged to
      MLflow.
    * Before the next dataset, ``extraUnits`` new units are added to the
      network.

    If a saved context exists for ``params["name"]``, training resumes. Such a
    context is written when a CUDA out-of-memory error occurs. Datasets before
    the recorded one, and epochs before the recorded one within it, are
    skipped. The recorded epoch is retried and everything after it runs
    normally. The context file is deleted at the end of a resumed run.

    Args:
        runParams: Run configuration dict (see above).

    Returns:
        The test accuracy reported by ``test`` for the last epoch of the last
        trained dataset, or ``0`` if nothing was trained.
    """
    if runParams["paramsFile"] is not None:
        with open(runParams["paramsFile"], "r") as f:
            params = yaml.safe_load(f)
        params["device"] = runParams["device"]
        params["subDeviceIdx"] = runParams["subDeviceIdx"]
    else:
        params = runParams

    use_cuda = params["device"] == "cuda" and torch.cuda.is_available()

    subDeviceIdx = params["subDeviceIdx"]
    if subDeviceIdx is None:
        subDeviceIdx = 0

    device = torch.device(
        "cuda:{}".format(subDeviceIdx) if use_cuda else "cpu")
    seed = params["training"]["seed"]
    if seed is None:
        seed = np.random.randint(10000)
        logger.debug("Using random seed")
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

    dataList = _asList(params["training"]["datasets"])
    logger.info(f"Datasets: {dataList}")
    nDatasets = len(dataList)

    batchSize = params["training"]["batchSize"]
    epochs = _asList(params["training"]["epochs"], nDatasets)
    extraUnits = _asList(params["training"]["extraUnits"],
                         nDatasets)  # type: List[int]

    dataShape = params["training"]["inputShape"]
    dataSize = dataShape[-2:]
    dataChannels = dataShape[1]

    fullLossFactor = _asList(params["training"]["fullLossCoeff"], nDatasets)
    l1LossFactor = _asList(params["training"]["l1_residuCoeff"], nDatasets)
    l2LossFactor = _asList(params["training"]["l2_residuCoeff"], nDatasets)
    learningRate = _asList(params["training"]["learningRate"], nDatasets)

    expName = _genExpName(dataList)
    experiment = mlflow.get_experiment_by_name(expName)
    if experiment is None:
        logger.info("Creating new experiment")
        expID = mlflow.create_experiment(expName)
    else:
        logger.info("Using existing experiment")
        expID = experiment.experiment_id

    contextName = params["name"]
    context = loadContext(contextName)
    resuming = context is not None
    runID = context["runID"] if resuming else None
    # Where the interrupted run stopped; only meaningful when resuming.
    resumeDatasetIdx = context["datasetIdx"] if resuming else 0
    resumeEpoch = context["epoch"] if resuming else 1

    # Returned value; stays 0 if no dataset ends up being trained.
    currentAcc = 0

    with mlflow.start_run(run_id=runID, experiment_id=expID):
        modelName = params["name"]
        if not resuming:
            mlflow.log_param("params", params)
            mlflow.log_param("name", modelName)

        # NOTE(review): when resuming at resumeDatasetIdx > 0, the saved hrnet
        # state already contains the units added for earlier datasets. Confirm
        # that load_state_dict on a freshly built network accepts it.
        hrnet = HashRoutedNetwork(modelName, params, device).to(device)
        if resuming:
            hrnet.load_state_dict(context["hrnet"])

        logDir = "../data/logs/" + modelName
        tbWriter = SummaryWriter(logDir)

        if resuming:
            previousTestData = context["pTestData"]
            previousDecoderDataSize = context["pDecoderDataSize"]
            previousClassifiers = []
            for idx, classifierState in enumerate(context["pClassifiers"]):
                oldClassifier = Decoder(
                    _asList(params["routing"]["decoder"], nDatasets)[idx],
                    params["routing"]["embeddingSize"],
                    params["routing"]["basisSize"] * hrnet.NUnits,
                    params["training"]["inputShape"],
                    device,
                    False,
                ).to(device)
                oldClassifier.load_state_dict(classifierState)
                previousClassifiers.append(oldClassifier)
        else:
            previousTestData = []
            previousClassifiers = []
            previousDecoderDataSize = []

        for datasetIdx, dataset in enumerate(dataList):
            # Datasets completed before the interruption are skipped. The
            # interrupted one is resumed, and later ones are trained normally.
            if resuming and datasetIdx < resumeDatasetIdx:
                continue
            resumingDataset = resuming and datasetIdx == resumeDatasetIdx

            fullBasisSize = params["routing"]["basisSize"] * hrnet.NUnits
            embeddingSize = params["routing"]["embeddingSize"]

            if USE_PROJ:
                decoderDataSize = 2 * embeddingSize  # + fullBasisSize
            else:
                decoderDataSize = embeddingSize  # 2*embeddingSize #+ fullBasisSize

            # one decoder per dataset
            classifier = Decoder(
                _asList(params["routing"]["decoder"], nDatasets)[datasetIdx],
                embeddingSize,
                fullBasisSize,
                params["training"]["inputShape"],
                device,
                False,
            ).to(device)
            # Only the interrupted dataset's decoder was saved in the context.
            if resumingDataset:
                classifier.load_state_dict(context["classifier"])

            logger.info(f"\n\t==== {dataset}: TRAINING ====\n")
            train_loader, test_loader = getDatasets(dataset, batchSize,
                                                    dataSize, dataChannels)

            optimParams = chain(hrnet.parameters(), classifier.parameters())
            optimizer = optim.Adam(optimParams, lr=learningRate[datasetIdx])

            currentAcc = 0
            currentCorrect = 0
            currentTotalSize = len(test_loader.dataset)
            for epoch in range(1, epochs[datasetIdx] + 1):
                # Epochs finished before the interruption are skipped; the
                # interrupted epoch itself is retried.
                if resumingDataset and epoch < resumeEpoch:
                    continue
                try:
                    train(
                        hrnet,
                        classifier,
                        device,
                        train_loader,
                        optimizer,
                        epoch,
                        fullLossFactor[datasetIdx],
                        l1LossFactor[datasetIdx],
                        l2LossFactor[datasetIdx],
                        tbWriter,
                    )
                    logger.info(f"\n\t==== {dataset}: TEST ({dataset}) ====\n")
                    currentAcc, currentCorrect = test(
                        hrnet,
                        classifier,
                        device,
                        test_loader,
                        len(train_loader),
                        epoch,
                        fullLossFactor[datasetIdx],
                        l1LossFactor[datasetIdx],
                        l2LossFactor[datasetIdx],
                        tbWriter,
                    )
                except RuntimeError as e:
                    if "out of memory" not in str(e):
                        raise
                    # Persist enough state to resume at this dataset/epoch, then
                    # exit so that an outer driver can restart the process.
                    logger.warning(
                        f"Runtime error:\n{e}\nSaving models, deleting context and "
                        f"retrying...")
                    saveContext(
                        contextName,
                        hrnet,
                        classifier,
                        previousClassifiers,
                        previousTestData,
                        previousDecoderDataSize,
                        datasetIdx,
                        epoch,
                    )
                    raise SystemExit(-1)

            saveModel(hrnet, f"../data/{modelName}_{dataset}.pt")
            saveModel(classifier,
                      f"../data/{modelName}_{dataset}_classifier.pt")

            del train_loader

            totalCorrect = currentCorrect
            totalDatasetSize = currentTotalSize
            for pIdx, pTestData in enumerate(previousTestData):
                previousDataName = dataList[pIdx]
                previousDecodDataSize = previousDecoderDataSize[pIdx]
                logger.info(
                    f"\n\t==== {dataset}: Lifelong TEST  ({previousDataName}) "
                    f"====\n")
                testDatasetSize = len(pTestData.dataset)
                totalDatasetSize += testDatasetSize
                logger.info("Re-encoding test data")
                newEncodedTestData = encodeDataset(hrnet, device, pTestData,
                                                   previousDecodDataSize)
                pClassifier = previousClassifiers[pIdx].to(device)

                previousAcc, previousCorrect = testClassifier(
                    pClassifier,
                    newEncodedTestData,
                    testDatasetSize,
                    epochs[datasetIdx],
                    device,
                    f"lifelong/{previousDataName}",
                )

                totalCorrect += previousCorrect

                # Running accuracy over the current task plus tasks 0..pIdx.
                globalAcc = 100.0 * totalCorrect / totalDatasetSize
                mlflow.log_metric("lifelong/globalAccuracy", globalAcc,
                                  epochs[datasetIdx])
                logger.info(
                    f"Global accuracy at task {pIdx}: {globalAcc:.0f}% "
                    f"({totalCorrect}/{totalDatasetSize})")

            previousTestData.append(test_loader)
            previousClassifiers.append(classifier.cpu())
            previousDecoderDataSize.append(decoderDataSize)

            if datasetIdx < len(dataList) - 1:
                # +1 because the extra units list starts with a 0
                nNewUnits = extraUnits[datasetIdx + 1]
                logger.info(f"Adding {nNewUnits} unit to network")
                previousKoh = hrnet.cpu()
                hrnet = previousKoh.addRndNetworkUnits(nNewUnits)
                del previousKoh
                hrnet = hrnet.to(device)

        mlflow.log_artifacts(logDir, artifact_path="events")

    if resuming:
        os.remove(f"../data/{contextName}_context.pt")

    return currentAcc
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        """Build a per-sentence classification (or regression) head over a BERT embedder.

        Args:
            vocab: Vocabulary. When ``sci_sum`` is False, its ``'labels'``
                namespace defines the classes.
            text_field_embedder: Embedder whose ``'bert'`` token embedder
                provides ``output_dim``.
            use_sep: If True, the feed-forward layer's input size is the BERT
                output dimension. Otherwise it is ``self_attn.get_output_dim()``.
            with_crf: Add a ``ConditionalRandomField`` over the labels.
            self_attn: Encoder; required when ``use_sep`` is False.
            bert_dropout: Dropout probability.
            sci_sum: If True, use an MSE loss with a single output
                (labels are scores). Otherwise use cross-entropy over the
                label vocabulary.
            additional_feature_size: Number of extra features appended to the
                feed-forward input.

        Raises:
            ValueError: If ``use_sep`` is False and ``self_attn`` is None.
        """
        super(SeqClassificationModel, self).__init__(vocab)

        # Without an encoder, the feed-forward input size cannot be determined.
        if not use_sep and self_attn is None:
            raise ValueError(
                '`self_attn` must be provided when `use_sep` is False.')

        self.track_embedding_list = []
        self.track_embedding = {}
        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            # Label index -1 is ignored by the loss.
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_sentence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        if self.use_sep:
            ff_in_dim = encoded_sentence_dim
        else:
            ff_in_dim = self_attn.get_output_dim()
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)
        self.track_embedding["init_info"] = {
            "ff_in_dim": ff_in_dim,
            "encoded_sentence_dim": encoded_sentence_dim,
            "sci_sum": self.sci_sum,
            "use_sep": self.use_sep,
            "with_crf": self.with_crf,
            "additional_feature_size": self.additional_feature_size
        }
        # NOTE(review): SummaryWriter() without a log_dir picks a default run
        # directory on every construction. add_graph is called without
        # input_to_model, so tracing forward() here may fail. Confirm with the
        # SummaryWriter implementation in use.
        self.t_board_writer = SummaryWriter()
        self.t_board_writer.add_graph(self)
示例#9
0
    def __init__(self,
                 lr: float,
                 gamma: float,
                 obs_dims,
                 num_actions: int,
                 mem_size,
                 mini_batchsize,
                 epsilon_dec,
                 env_name,
                 algo_name,
                 epsilon=1.0,
                 replace=1000,
                 epsilon_min=0.1,
                 aug_size=4,
                 random_noise_std=0.1,
                 checkpoint_dir='results\\MeanTeacherDDQN'):
        """Create the replay memory, the three Q-networks and a TensorBoard writer.

        Three ``DeepQNetwork`` instances (learning, target and teacher) share
        ``lr``, ``num_actions``, ``obs_dims`` and ``checkpoint_dir``, and
        differ only in their ``name``. The TensorBoard writer logs to
        ``<checkpoint_dir>/logs``.

        Args:
            lr: Learning rate passed to every network.
            gamma: Discount factor (stored).
            obs_dims: Observation shape, used for the memory and the networks.
            num_actions: Size of the discrete action space.
            mem_size: Replay memory capacity.
            mini_batchsize: Mini-batch size (stored).
            epsilon_dec: Epsilon decrement (stored).
            env_name: Environment name, part of the network names.
            algo_name: Algorithm name, part of the network names.
            epsilon: Initial exploration rate.
            replace: Stored as ``replace_target_cnt``.
            epsilon_min: Lower bound for epsilon (stored).
            aug_size: Augmentation count passed to the replay memory.
            random_noise_std: Noise standard deviation (stored).
            checkpoint_dir: Directory for network checkpoints and logs.
        """
        # Learning hyper-parameters.
        self.lr = lr
        self.gamma = gamma
        self.mini_batchsize = mini_batchsize
        self.random_noise_std = random_noise_std

        # Exploration schedule and target-sync interval.
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_dec = epsilon_dec
        self.replace_target_cnt = replace

        # Problem dimensions.
        self.obs_dims = obs_dims
        self.num_actions = num_actions
        self.aug_size = aug_size
        self.action_space = list(range(num_actions))

        # Counters and running costs.
        self.mem_counter = 0
        self.copy_counter = 0
        self.online_cost = 0
        self.teacher_cost = 0

        self.checkpoint_dir = checkpoint_dir
        self.memories = ReplayBufferTeacher(mem_size=mem_size,
                                            state_shape=obs_dims,
                                            aug_size=aug_size,
                                            num_actions=num_actions)

        def build_network(network_name):
            # All three networks share every constructor argument except the name.
            return DeepQNetwork(lr=lr,
                                num_actions=num_actions,
                                input_dims=obs_dims,
                                name=network_name,
                                checkpoint_dir=checkpoint_dir)

        self.learning_network = build_network(
            algo_name + '_' + env_name + '_' + 'learning')
        # NOTE(review): the target name puts env_name before algo_name, unlike
        # the other two. It is kept as-is because names likely determine
        # checkpoint file names.
        self.target_network = build_network(
            env_name + '_' + algo_name + '_target')
        self.teacher_network = build_network(
            algo_name + '_' + env_name + '_' + 'teacher')

        self.writer = SummaryWriter(os.path.join(checkpoint_dir, 'logs'))
示例#10
0
def log_values(
    cost,
    grad_norms,
    bl_val,
    epoch,
    batch_id,
    step,
    log_likelihood,
    reinforce_loss,
    bl_loss,
    log_p,
    logger: SummaryWriter,
    args,
):
    """Report the statistics of one training batch to stdout and TensorBoard.

    :param cost: per-instance costs; their mean is logged as ``avg_cost``.
    :param grad_norms: pair ``(norms, clipped_norms)``; element ``[0]`` of each is logged.
    :param bl_val: baseline predictions; their mean is printed.
    :param epoch: current epoch (printed only).
    :param batch_id: batch index; every 100th batch also logs probability histograms.
    :param step: global step used as the TensorBoard x-axis.
    :param log_likelihood: per-instance log-likelihoods; the negated mean is logged as ``nll``.
    :param reinforce_loss: scalar actor loss.
    :param bl_loss: critic loss (currently not logged).
    :param log_p: log-probabilities unpacked as ``(num_graph, num_step, num_node)``.
    :param logger: TensorBoard writer receiving the scalars and histograms.
    :param args: run configuration (currently unused).
    """
    mean_cost = cost.mean().item()
    mean_baseline = bl_val.mean().item()
    norms, clipped_norms = grad_norms

    # Console output
    print("epoch: {}, train_batch_id: {}, avg_cost: {}, baseline predict: {}".
          format(epoch, batch_id, mean_cost, mean_baseline))
    print("grad_norm: {}, clipped: {}".format(norms, clipped_norms))

    # TensorBoard scalars
    scalar_entries = (
        ("avg_cost", mean_cost),
        ("grad_norm", norms[0]),
        ("grad_norm_clipped", clipped_norms[0]),
        ("actor_loss", reinforce_loss.item()),
        ("nll", -log_likelihood.mean().item()),
    )
    for tag, value in scalar_entries:
        logger.add_scalar(tag, value, step)
    # The critic loss is intentionally not logged at the moment:
    # if args.baseline == "critic":
    #     logger.add_scalar("critic_loss", bl_loss.item(), step)

    if batch_id % 100 == 0:
        # Unpacking also enforces that log_p is 3-D.
        _, num_step, _ = log_p.shape
        first_graph = log_p.cpu()[0]
        histogram_rows = (
            ("first_step_prob", 0),
            ("mid_step_prob", num_step // 2),
            ("last_step_prob", -1),
        )
        for tag, row in histogram_rows:
            logger.add_histogram(tag, first_graph[row].exp().squeeze(), step)
class SeqClassificationModel(Model):
    """
    Sequential sentence classification model.

    Assigns a label to every sentence of an example. When ``sci_sum`` is set, it
    assigns a regression score instead.

    How each sentence is represented:
    - With ``use_sep``, a sentence is the BERT output at its ``[SEP]`` token
      (token id 103).
    - Otherwise, the ``[CLS]`` position of each sentence is taken and
      contextualised with ``self_attn``.

    A ``TimeDistributed`` linear layer then maps each sentence representation to
    ``num_labels`` logits. The logits are optionally decoded with a CRF.

    For inspection, shapes of intermediate tensors are collected in
    ``track_embedding``. On every call, ``forward`` appends a copy of it to
    ``track_embedding_list`` and rewrites the whole list to the JSON file
    ``path_json``, a name defined outside this class.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        """
        :param vocab: vocabulary; its ``labels`` namespace defines the classes.
        :param text_field_embedder: embedder that must contain a ``'bert'`` token embedder.
        :param use_sep: represent sentences by their ``[SEP]`` token embeddings.
        :param with_crf: decode sentence labels jointly with a CRF.
        :param self_attn: encoder over ``[CLS]`` embeddings; required when ``use_sep`` is False.
        :param bert_dropout: dropout probability applied to sentence embeddings.
        :param sci_sum: treat labels as regression targets (MSE loss, one output).
        :param additional_feature_size: width of extra per-sentence features concatenated
            before the classifier.
        """
        super(SeqClassificationModel, self).__init__(vocab)

        self.track_embedding_list = []
        self.track_embedding = {}
        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_sentence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        if self.use_sep:
            ff_in_dim = encoded_sentence_dim
        else:
            ff_in_dim = self_attn.get_output_dim()
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)
        self.track_embedding["init_info"] = {
            "ff_in_dim": ff_in_dim,
            "encoded_sentence_dim": encoded_sentence_dim,
            "sci_sum": self.sci_sum,
            "use_sep": self.use_sep,
            "with_crf": self.with_crf,
            "additional_feature_size": self.additional_feature_size
        }
        self.t_board_writer = SummaryWriter()
        # NOTE(review): add_graph is called without `input_to_model`; graph tracing
        # usually needs an example input — confirm this works with the installed writer.
        self.t_board_writer.add_graph(self)

    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Classify every sentence of the batch and, when ``labels`` are given, compute the loss.

        Parameters
        ----------
        sentences : mapping of token-indexer name to tensor
            Despite the annotation, this is used as a dict. ``sentences['bert']``
            holds the token ids, and the whole dict is passed to
            ``text_field_embedder`` and ``get_text_field_mask``.
        labels : torch.IntTensor, optional
            Gold value per sentence. Padding is ``-1``, or ``0.0`` when labels are scores.
        confidences : torch.Tensor, optional
            Per-sentence weights multiplied into the unreduced loss.
        additional_features : torch.Tensor, optional
            Extra per-sentence features concatenated to each sentence representation.

        Returns
        -------
        An output dictionary consisting of:
        action_probs : torch.FloatTensor
            Softmax over labels, or the raw scores when labels are scores.
        action_logits : torch.FloatTensor
            batch_size x num_sentences x num_labels logits.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised; only present when ``labels`` is given.
        """
        # ===========================================================================================================
        # Layer 1: embed every token of every sentence with the text field embedder (BERT)
        # Input: sentences
        # Output: embedded_sentences
        print(sentences)
        # JSON-serialisable copy of the input token tensors for the debug trace.
        # (This used to iterate over the still-empty result dict, so nothing was recorded.)
        sentences_conv = {
            key: val.cpu().data.numpy().tolist()
            for key, val in sentences.items()
        }
        self.track_embedding["Transformation_0"] = {
            "sentences": sentences_conv
        }
        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        self.track_embedding["Transformation_1"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }

        # Padding mask over the BERT tokens
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = list(embedded_sentences.size())

        if self.use_sep:
            # Collect the [SEP] token vectors of all examples in the batch into one list,
            # and do the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences[
                'bert'] == 103  # mask for all the SEP tokens in the batch
            # Select the BERT outputs at the [SEP] positions.
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
            self.track_embedding["Transformation_2"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Boolean-mask indexing with a mask over the first three axes
            # (batch_size x num_sentences x sent_len) flattens those axes into one.
            # The result is therefore 2-D:
            #   size() is the shape, (num_SEP_tokens, vector_len);
            #   dim() is the number of axes, 2.
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # The rest of the model treats all sentences of the batch as a single example,
            # i.e. one long sequence with a batch of size 1.
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(
                dim=0)  # 1 x num_sentences_per_batch x vector_len
            self.track_embedding["Transformation_3"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Dropout sits between the selected embeddings and the linear layer
            embedded_sentences = self.dropout(embedded_sentences)
            self.track_embedding["Transformation_4"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Labels (one per sentence) are only provided for training / evaluation
            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # but `num_labels` should be at least greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # Use the [CLS] token (position 0) of every sentence
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = list(embedded_sentences.size())
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        # TimeDistributed folds the sentence axis into the batch axis, applies the linear
        # layer, and unfolds it again.
        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels
        self.track_embedding["logits"] = {
            "size": list(label_logits.size()),
            "dim": label_logits.dim()
        }
        # Debug trace: the whole history is rewritten on every call.
        self.track_embedding_list.append(deepcopy(self.track_embedding))
        with open(path_json, 'w') as json_out:
            json.dump(self.track_embedding_list, json_out)

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            # NOTE(review): with labels=None this compares None with -1 (a Python bool),
            # so CRF decoding at prediction time without labels looks broken — confirm.
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]

            label_loss = 0.0
        if labels is not None:
            # Flatten logits to (batch_size * num_sentences) x num_labels
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            # Gold labels as one contiguous 1-D tensor
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                # Per-sentence loss (MSE when sci_sum, otherwise cross-entropy)
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy from one-hot Viterbi predictions
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        """Return accuracy, per-label F1 (``<label>F``) and their mean (``avgF``).

        Required by the AllenNLP ``Model`` API. Returns an empty dict when labels
        are regression scores.
        """
        metric_dict = {}

        if not self.labels_are_scores:
            type_accuracy = self.label_accuracy.get_metric(reset)
            metric_dict['acc'] = type_accuracy

            average_F1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                metric_dict[name + 'F'] = metric_val[2]
                average_F1 += metric_val[2]

            # Guard against an empty label vocabulary (previously ZeroDivisionError).
            if self.label_f1_metrics:
                average_F1 /= len(self.label_f1_metrics)
            metric_dict['avgF'] = average_F1

        return metric_dict
示例#12
0
class Agent:
    """PPO-clip agent for an environment with continuous actions.

    Each environment step is stored in ``self.memory`` as
    ``(state, action, [reward / 100], next_state, [done], [log_prob])``.
    The memory is consumed by :meth:`train` after every ``sample_size`` steps,
    or when the episode ends.
    """

    def __init__(self,
                 env: gym.Env,
                 epoch=5,
                 lr=1e-5,
                 gamma=0.99,
                 epsilon=0.2,
                 lamda=0.98,
                 **kwargs):
        """
        :param env: environment whose observation/action spaces expose ``.shape``
            (only the last dimension of each is read).
        :param epoch: optimisation passes over each collected batch.
        :param lr: Adam learning rate.
        :param gamma: discount factor.
        :param epsilon: PPO clipping range for the probability ratio.
        :param lamda: GAE lambda.
        :param kwargs: forwarded to both ``Network`` and ``Policy``.
        """
        self.env = env
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.network = Network(env.observation_space.shape[-1],
                               env.action_space.shape[-1],
                               **kwargs).to(self.device)
        self.optim = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.policy = Policy(self.network, **kwargs)
        self.writer = SummaryWriter(
            log_dir=f"./runs/ppo-continuous-{datetime.now()}")
        self.global_step, self.train_step = 0, 0
        self.epoch = epoch
        self.epsilon, self.gamma, self.lamda = epsilon, gamma, lamda
        self.memory = []

    def tr(self, x: Iterable) -> torch.Tensor:
        """Stack ``x`` row-wise into a float32 tensor on the agent's device."""
        return torch.tensor(np.vstack(x), dtype=torch.float32).to(self.device)

    def get_batch(self) -> Tuple[torch.Tensor, ...]:
        """Drain the memory into one tensor per transition field.

        :returns: ``(states, actions, rewards, next_states, dones, log_probs)``,
            each with one row per stored transition.
        """
        # zip(*memory) transposes transitions into per-field columns. The former
        # np.array(memory).transpose() built a ragged array, which NumPy >= 1.24
        # rejects unless dtype=object is requested explicitly.
        columns = list(zip(*self.memory))
        self.memory.clear()
        return tuple(self.tr(column) for column in columns)

    def gae(self, delta):
        """Generalized Advantage Estimation over a single trajectory segment.

        Computes ``A_t = delta_t + gamma * lamda * A_{t+1}`` backwards along axis 0.
        The result is detached, has the same shape as ``delta``, and lives on the
        same device as ``delta`` (previously it was always created on the CPU).
        """
        deltas = delta.detach()
        advantages = torch.empty_like(deltas)
        running = torch.zeros_like(deltas[0])
        for i in reversed(range(deltas.shape[0])):
            running = self.gamma * self.lamda * running + deltas[i]
            advantages[i] = running
        return advantages

    def train(self):
        """Run ``self.epoch`` PPO updates on the transitions currently in memory."""
        s, a, r, ns, d, p = self.get_batch()
        for _ in range(self.epoch):
            # Semi-gradient TD target: no gradient may flow through V(s').
            target = (r + (1 - d) * self.gamma *
                      self.network.forward_value(ns)).detach()
            value = self.network.forward_value(s)
            value_loss = F.smooth_l1_loss(value, target)
            delta = target - value
            advantage = self.gae(delta)
            mu, std = self.network.forward_policy(s)
            new_action_dist = Normal(mu, std)
            new_p = new_action_dist.log_prob(a)
            ratio = torch.exp(new_p - p)
            surrogate = ratio * advantage
            surrogate_clipped = torch.clamp(
                ratio, min=1 - self.epsilon, max=1 + self.epsilon) * advantage
            # torch.min(a, b) is element-wise; the former `[0]` kept only the first
            # sample's surrogate instead of averaging over the whole batch.
            policy_loss = -torch.min(surrogate_clipped, surrogate).mean()
            self.optim.zero_grad()
            (policy_loss + value_loss).backward()
            self.optim.step()
            self.writer.add_scalar("training/loss/policy", policy_loss.item(),
                                   self.train_step)
            self.writer.add_scalar("training/loss/policy_abs",
                                   policy_loss.abs().item(), self.train_step)
            self.writer.add_scalar("training/loss/value", value_loss.item(),
                                   self.train_step)
            self.writer.add_scalar("training/1-ratio_abs",
                                   (1 - ratio).mean().abs().item(),
                                   self.train_step)
            self.writer.add_scalar("training/advantage",
                                   advantage.mean().item(), self.train_step)
            self.train_step += 1

    def __call__(self,
                 n_epi=10000,
                 sample_size=30,
                 train=True,
                 objective=-350):
        """Run up to ``n_epi`` episodes.

        :param n_epi: maximum number of episodes.
        :param sample_size: environment steps collected between updates.
        :param train: update the networks; otherwise render and print the running score.
        :param objective: stop once the mean score of the last 10 episodes reaches this.
        """
        scores = deque(maxlen=10)
        for epi in range(n_epi):
            sc = 0
            s = self.env.reset()
            d = False
            while not d:
                for _ in range(sample_size):
                    a, p, mean, std = self.policy.get_action(self.tr([s]))
                    ns, r, d, _ = self.env.step(a)
                    self.memory.append((s, a, [r / 100], ns, [d], [p]))
                    sc += r
                    s = ns
                    # The terminal step is stored but not logged below.
                    if d: break
                    self.writer.add_scalar("performance/reward", r,
                                           self.global_step)
                    self.writer.add_scalar("action/action", a.squeeze(),
                                           self.global_step)
                    self.writer.add_scalar("action/mean", mean,
                                           self.global_step)
                    self.writer.add_scalar("action/stddev", std,
                                           self.global_step)
                    self.writer.add_scalar("action/prob", np.exp(p),
                                           self.global_step)
                    self.global_step += 1
                if train: self.train()
                else:
                    self.env.render()
                    print("score:", sc)
            self.writer.add_scalar("performance/score", sc, epi)
            scores.append(sc)
            if np.mean(scores) >= objective: break
示例#13
0
文件: run.py 项目: WenjinW/LLSEU
########################################################################################################################
# define logger
logger = logging.getLogger()

# define tensorboard writer
exp_name = args.experiment+'_'+args.approach+'_'+str(args.seed)+'_'+args.id

# in polyaxon
if args.location == 'polyaxon':
    from polyaxon_client.tracking import Experiment
    experiment = Experiment()
    output_path = experiment.get_outputs_path()
    print("Output path: {}".format(output_path))
    logger.info("Output path: %s", output_path)
    writer = SummaryWriter(log_dir='/'+output_path)
else:
    writer = SummaryWriter(log_dir='../logs/' + exp_name)

# Load data
print('Load data...')

data, taskcla, inputsize = dataloader.get(path='../dat/', seed=args.seed)
print('Input size =', inputsize, '\nTask info =', taskcla)
# logging uses %-style formatting. Extra positional args must match placeholders;
# passing them print()-style makes logging report a formatting error instead of
# the message.
logger.info('Input size = %s\nTask info = %s', inputsize, taskcla)


# logging the experiment config
config_exp = ["name: {}".format(exp_name),
              "mode: {}".format(args.mode),
              "dataset: {}".format(args.experiment),
示例#14
0
 def __init__(self, log_dir: str = None) -> None:
     """Open the TensorBoard writer that this object delegates to.

     :param log_dir: forwarded unchanged to ``SummaryWriter``; ``None`` leaves
         the choice of directory to ``SummaryWriter``.
     """
     writer = SummaryWriter(log_dir=log_dir)
     self._writer = writer
示例#15
0
    def __init__(self,
                 workflow_id,
                 phase,
                 artifacts_dir,
                 comparison_pairs=None):
        """
        This method instantiates an object of type TensorboardSummaryHook. The signature of this method is similar to
        that of every other hook. There is one additional parameter called `comparison_pairs` which is meant to
        hold a list of lists each containing a pair of input/output names that share the same dimensionality and can be
        compared to each other.

        A typical use of `comparison_pairs` is when users want to plot a pr_curve or a confusion matrix by comparing
        some input with some output. Eg. by comparing the labels with the predictions.

        .. code-block:: python

            from eisen.utils.logging import TensorboardSummaryHook

            workflow = # Eg. An instance of Training workflow

            logger = TensorboardSummaryHook(
                workflow_id=workflow.id,
                phase='Training',
                artifacts_dir='/artifacts/dir',
                comparison_pairs=[['labels', 'predictions']]
            )

        :param workflow_id: string containing the workflow id of the workflow being monitored (workflow_instance.id)
        :type workflow_id: UUID
        :param phase: string containing the name of the phase (training, testing, ...) of the workflow monitored
        :type phase: str
        :param artifacts_dir: existing directory; summaries are written to `<artifacts_dir>/summaries/<phase>`
        :type artifacts_dir: str
        :param comparison_pairs: list of lists of pairs, which are names of inputs and outputs to be compared directly
        :type comparison_pairs: list of lists of strings
        :raises ValueError: if `artifacts_dir` does not exist

        <json>
        [
            {"name": "comparison_pairs", "type": "list:list:string", "value": ""}
        ]
        </json>
        """
        self.workflow_id = workflow_id
        self.phase = phase

        self.comparison_pairs = comparison_pairs

        if not os.path.exists(artifacts_dir):
            raise ValueError(
                'The directory specified to save artifacts does not exist!')

        self.artifacts_dir = os.path.join(artifacts_dir, 'summaries', phase)

        # exist_ok avoids the race between a separate exists() check and makedirs().
        os.makedirs(self.artifacts_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.artifacts_dir)

        # Subscribe last: if anything above raises, no dangling handler is left
        # connected, and end_epoch (which presumably writes via self.writer — confirm)
        # can only fire once the writer exists.
        dispatcher.connect(self.end_epoch,
                           signal=EISEN_END_EPOCH_EVENT,
                           sender=workflow_id)
示例#16
0
class MeanTeacherAgent:
    """Double DQN with a Mean-Teacher-style semi-supervised learning strategy.

    Three networks are kept:
    - ``learning_network``: the online network.
    - ``target_network``: copied from the online network every
      ``replace_target_cnt`` learning steps.
    - ``teacher_network``: trained with the same Double-DQN loss on ``aug_size``
      noise-perturbed copies of every sampled state.

    TD losses, and a summed online/teacher parameter difference, are logged to
    TensorBoard under ``<checkpoint_dir>/logs``.
    """

    def __init__(self,
                 lr: float,
                 gamma: float,
                 obs_dims,
                 num_actions: int,
                 mem_size,
                 mini_batchsize,
                 epsilon_dec,
                 env_name,
                 algo_name,
                 epsilon=1.0,
                 replace=1000,
                 epsilon_min=0.1,
                 aug_size=4,
                 random_noise_std=0.1,
                 checkpoint_dir='results\\MeanTeacherDDQN'):
        """
        :param lr: learning rate of all three networks.
        :param gamma: discount factor.
        :param obs_dims: observation shape (unpacked with ``*`` when reshaping).
        :param num_actions: size of the discrete action space.
        :param mem_size: replay buffer capacity.
        :param mini_batchsize: transitions sampled per learning step.
        :param epsilon_dec: linear epsilon decrement per learning step.
        :param env_name: used in checkpoint names.
        :param algo_name: used in checkpoint names.
        :param epsilon: initial exploration rate.
        :param replace: learning steps between target-network copies.
        :param epsilon_min: lower bound for epsilon.
        :param aug_size: noisy copies generated per stored state.
        :param random_noise_std: std of the augmentation noise.
        :param checkpoint_dir: directory for checkpoints and TensorBoard logs.
        """
        self.lr = lr
        self.gamma = gamma
        self.obs_dims = obs_dims
        self.aug_size = aug_size
        self.num_actions = num_actions
        self.mini_batchsize = mini_batchsize
        self.epsilon_min = epsilon_min
        self.epsilon_dec = epsilon_dec
        self.epsilon = epsilon
        self.replace_target_cnt = replace

        self.mem_counter = 0
        self.copy_counter = 0
        self.checkpoint_dir = checkpoint_dir
        self.memories = ReplayBufferTeacher(mem_size=mem_size,
                                            state_shape=self.obs_dims,
                                            aug_size=aug_size,
                                            num_actions=self.num_actions)
        self.action_space = [i for i in range(self.num_actions)]

        self.learning_network = DeepQNetwork(
            lr=self.lr,
            num_actions=self.num_actions,
            input_dims=self.obs_dims,
            name=algo_name + '_' + env_name + '_' + 'learning',
            checkpoint_dir=self.checkpoint_dir)

        # NOTE: name order differs from the other two networks (env first);
        # kept as-is so existing checkpoint files still match.
        self.target_network = DeepQNetwork(lr=self.lr,
                                           num_actions=self.num_actions,
                                           input_dims=self.obs_dims,
                                           name=env_name + '_' + algo_name +
                                           '_target',
                                           checkpoint_dir=self.checkpoint_dir)

        self.teacher_network = DeepQNetwork(lr=self.lr,
                                            num_actions=self.num_actions,
                                            input_dims=self.obs_dims,
                                            name=algo_name + '_' + env_name +
                                            '_' + 'teacher',
                                            checkpoint_dir=self.checkpoint_dir)

        self.random_noise_std = random_noise_std
        self.online_cost = 0
        self.teacher_cost = 0
        self.writer = SummaryWriter(os.path.join(self.checkpoint_dir, 'logs'))

    def decrement_epsilon(self):
        """Linearly decay epsilon; once it is at or below ``epsilon_min`` it is clamped there."""
        if self.epsilon > self.epsilon_min:
            self.epsilon = self.epsilon - self.epsilon_dec
        else:
            self.epsilon = self.epsilon_min

    def augment_state(self, obs):
        """Return ``aug_size`` noisy copies of ``obs`` (noise variance ``random_noise_std ** 2``)."""
        return [
            random_noise(obs, var=self.random_noise_std**2)
            for _ in range(self.aug_size)
        ]

    def store_memory(self, obs, action, reward, new_obs, done):
        """Store one transition together with augmented copies of its state."""
        aug_states = self.augment_state(obs)
        self.memories.store(obs, aug_states, action, reward, new_obs, done)
        self.mem_counter += 1

    def sample_memory(self):
        """Sample a mini-batch and build the online and teacher training batches.

        :returns: ``(states, actions, rewards, new_states, dones, teacher_states,
            teacher_actions, teacher_rewards, teacher_new_states, teacher_dones)``,
            all on the target network's device. The teacher batch holds
            ``mini_batchsize * aug_size`` rows: the augmented states flattened, with
            every other field repeated ``aug_size`` times per transition.
        """
        (_states, _aug_states, _actions, _rewards, _new_states,
         _dones) = self.memories.sample(self.mini_batchsize)
        device = self.target_network.device

        # Convert once; re-wrapping tensors with T.tensor only copied them again
        # and triggered PyTorch's copy-construct warning.
        _states = T.tensor(_states)
        _aug_states = T.tensor(_aug_states)
        _actions = T.tensor(_actions)
        _rewards = T.tensor(_rewards)
        _new_states = T.tensor(_new_states)
        _dones = T.tensor(_dones)

        states = _states.to(device)
        actions = _actions.to(device)
        rewards = _rewards.to(device)
        new_states = _new_states.to(device)
        dones = _dones.to(device)

        # Flatten augmented states into a batch of mini_batchsize * aug_size;
        # the other fields are repeated to line up with them.
        _aug_states = _aug_states.reshape(
            (_aug_states.shape[0] * _aug_states.shape[1], *self.obs_dims))
        teacher_states = _aug_states.to(device)
        teacher_actions = T.repeat_interleave(
            _actions, repeats=self.aug_size, dim=0).to(device)
        teacher_rewards = T.repeat_interleave(
            _rewards, repeats=self.aug_size, dim=0).to(device)
        teacher_new_states = T.repeat_interleave(
            _new_states, repeats=self.aug_size, dim=0).to(device)
        teacher_dones = T.repeat_interleave(
            _dones, repeats=self.aug_size, dim=0).to(device)

        return states, actions, rewards, new_states, dones, teacher_states, teacher_actions, teacher_rewards, teacher_new_states, teacher_dones

    def get_action(self, obs):
        """Epsilon-greedy action for a single observation."""
        if np.random.random() < self.epsilon:
            action = np.random.choice(len(self.action_space), 1)[0]
        else:
            state = T.tensor([obs],
                             dtype=T.float).to(self.learning_network.device)

            # NOTE(review): greedy actions come from the *target* network; Double DQN
            # usually acts with the online network — confirm this is intended.
            with T.no_grad():
                returns_for_actions = self.target_network.forward(state)
            action = T.argmax(returns_for_actions).cpu().detach().numpy()
        return action

    def _double_q_loss(self, network, states, actions, rewards, new_states,
                       dones):
        """Double-DQN TD loss of ``network`` on one batch.

        ``network`` predicts Q(s, a) and also selects the greedy next action;
        ``target_network`` evaluates that action. The TD target is built without
        autograd, so gradients no longer accumulate in ``target_network``, which is
        never zeroed or stepped.
        """
        indices = np.arange(len(actions))
        q_pred = network.forward(states)[indices, actions]

        with T.no_grad():
            # Action selection with the network being trained ...
            actions_selected = T.argmax(network.forward(new_states), dim=1)
            # ... evaluated by the target network; terminal states have no future return.
            q_eval = self.target_network.forward(new_states)
            q_eval[dones] = 0.0
            q_target = rewards + self.gamma * q_eval[indices, actions_selected]

        return network.loss(q_target, q_pred)

    def learn(self):
        """One optimisation step for the online and teacher networks.

        Also decays epsilon and copies the online weights into the target network
        every ``replace_target_cnt`` calls. Does nothing until at least
        ``mini_batchsize`` transitions have been stored.
        """
        if self.mem_counter < self.mini_batchsize:
            return

        self.learning_network.optimizer.zero_grad()
        self.teacher_network.optimizer.zero_grad()
        (states, actions, rewards, new_states, dones, teacher_states,
         teacher_actions, teacher_rewards, teacher_new_states,
         teacher_dones) = self.sample_memory()

        # === Learning for online net === #
        online_cost = self._double_q_loss(self.learning_network, states,
                                          actions, rewards, new_states, dones)
        online_cost.backward()
        self.learning_network.optimizer.step()

        # ==== Learning for teacher net ==== #
        teacher_cost = self._double_q_loss(self.teacher_network,
                                           teacher_states, teacher_actions,
                                           teacher_rewards, teacher_new_states,
                                           teacher_dones)
        teacher_cost.backward()
        self.teacher_network.optimizer.step()

        self.decrement_epsilon()
        if self.copy_counter % self.replace_target_cnt == 0:
            self.copy_target_network()
        self.copy_counter += 1

        # Keep only the values; the live loss tensors would retain their autograd history.
        self.online_cost = online_cost.detach()
        self.teacher_cost = teacher_cost.detach()

    def log(self, num_episode):
        """Write the latest TD losses and the online/teacher parameter difference.

        :returns: the summed signed parameter difference.
        """
        # NOTE(review): this is a *signed* sum, so positive and negative differences
        # cancel out; use the absolute value if a distance is intended.
        diff = 0
        for p_online, p_teacher in zip(self.learning_network.parameters(),
                                       self.teacher_network.parameters()):
            p_online = p_online.data.cpu()
            p_teacher = p_teacher.data.cpu()
            diff += T.sum(p_online - p_teacher)

        self.writer.add_scalar("Online td_error", self.online_cost,
                               num_episode)
        self.writer.add_scalar("Teacher td_error", self.teacher_cost,
                               num_episode)
        self.writer.add_scalar("online_teacher_diff", diff, num_episode)

        return diff

    def copy_target_network(self):
        """Overwrite the target network's weights with the online network's."""
        self.target_network.load_state_dict(self.learning_network.state_dict())

    def save_models(self):
        """Checkpoint the online and target networks (the teacher is not saved)."""
        self.learning_network.save()
        self.target_network.save()

    def load_models(self):
        """Restore the online and target networks from their checkpoints."""
        self.learning_network.load()
        self.target_network.load()
示例#17
0
def train(cfg):
    """Train the MobileNetV2-backed orientation/confidence model on KITTI.

    Reads ``PATH``, ``TRAIN_BATCH``, ``BIN``, ``DEVICE``, ``MODEL_DIR``,
    ``LOG_DIR``, ``EPOCH`` and ``WEIGHT`` from ``cfg``. The per-epoch mean
    total/confidence/orientation losses are written to TensorBoard under
    ``cfg.LOG_DIR``. The model's ``state_dict`` is saved to ``cfg.MODEL_DIR``
    on every 10th epoch (0, 10, 20, ...).

    Args:
        cfg: configuration object exposing the attributes listed above.
    """

    # load data
    train_dataset = dataset.KITTIDataset(root=cfg.PATH, cfg=cfg, mode='train')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.TRAIN_BATCH,
                                  shuffle=True)

    # Initialize the model. ``Model`` is handed the backbone's ``features``
    # module. Assuming it keeps that reference rather than a copy, the weights
    # loaded into ``mobilenetv2`` below also reach ``model``. TODO confirm in Model.
    mobilenetv2 = MobileNetV2()
    model = Model(features=mobilenetv2.features, bins=cfg.BIN).to(cfg.DEVICE)

    # Look for checkpoints written by earlier runs of this function.
    model_list = [
        x for x in sorted(os.listdir(cfg.MODEL_DIR)) if x.endswith(".pth")
    ]
    if not model_list:
        print("No previous model found, start training!")
        mobilenetv2_model = torch.load('./model/mobilenet_v2.pth.tar')
        mobilenetv2.load_state_dict(mobilenetv2_model)
    else:
        print("Find previous model %s" % model_list[-1])
        # NOTE(review): resuming is disabled. The load below is commented out,
        # so in this branch the backbone does not receive the pretrained
        # mobilenet_v2.pth.tar weights either. Re-enable to actually resume:
        # model.load_state_dict(torch.load(os.path.join(cfg.MODEL_DIR, model_list[-1]), map_location=torch.device(cfg.DEVICE)))

    opt_SGD = torch.optim.SGD(model.parameters(), lr=0.001)

    writer = SummaryWriter(cfg.LOG_DIR)
    try:
        for i in range(cfg.EPOCH):
            model.train()
            print('Epoch %d' % i)
            loss_epoch = 0.0
            conf_loss_epoch = 0.0
            orient_loss_epoch = 0.0
            num_batches = 0
            for batch_idx, (inputs, labels) in enumerate(train_dataloader):
                inputs = inputs.to(cfg.DEVICE)
                confidence = labels['confidence'].to(cfg.DEVICE)
                angle_offset = labels['angle_offset'].to(cfg.DEVICE)

                [orient, conf] = model(inputs)

                conf_loss = binary_cross_entropy_one_hot(conf, confidence)
                orient_loss = OrientationLoss(orient, angle_offset)
                loss = conf_loss + cfg.WEIGHT * orient_loss

                opt_SGD.zero_grad()
                loss.backward()
                opt_SGD.step()

                # Accumulate plain floats. Summing the loss tensors themselves
                # would chain every batch's autograd history onto the running
                # total and keep it alive for the whole epoch.
                loss_value = loss.item()
                conf_value = conf_loss.item()
                orient_value = orient_loss.item()
                loss_epoch += loss_value
                conf_loss_epoch += conf_value
                orient_loss_epoch += orient_value
                num_batches += 1

                print('Batch %d' % batch_idx)
                print('Training loss: ', loss_value)
                print('Confidence loss: ', conf_value)
                print('Orientation loss: ', orient_value)

            # An empty loader yields no batches. Skip the means instead of
            # dividing by zero.
            if num_batches:
                writer.add_scalar('Training loss: ', loss_epoch / num_batches, i)
                writer.add_scalar('Confidence loss: ',
                                  conf_loss_epoch / num_batches, i)
                writer.add_scalar('Orientation  loss: ',
                                  orient_loss_epoch / num_batches, i)

            if i % 10 == 0:
                now_s = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
                name = os.path.join(cfg.MODEL_DIR,
                                    "{}_model_{}.pth".format(cfg.BIN, now_s))
                torch.save(model.state_dict(), name)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    return
示例#18
0
                              shuffle=True,
                              num_workers=8)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8)
    test_loader = DataLoader(test_set,
                             batch_size=test_batch_size,
                             shuffle=False,
                             num_workers=8)

    # Run the experiments
    for seed in seeds:
        logger.info('Train model with seed {}'.format(seed))
        # TensorboardX writer
        writer = SummaryWriter(main_directory + '/runs/' + experiment_name +
                               '_' + str(seed))

        # The model
        torch.manual_seed(seed)
        model = WideNet().to(device)
        logger.info('Net parameters number : {}'.format(
            utils.compute_total_parameter_number(model)))

        optimizer = optim.SGD(model.parameters(),
                              lr=0.05,
                              momentum=0.9,
                              weight_decay=0.001)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[50, 100, 150],
                                                   gamma=0.1)
示例#19
0
class GeneralTrainer(Trainer):
    """Step-based trainer for the diacritization models.

    The model, optimizer, directories and text encoder all come from the
    ``ConfigManager`` built from ``config_path`` and ``model_kind``. Training
    resumes from ``train_resume_model_path``, or from the latest snapshot when
    one exists (see ``load_model``).
    """

    def __init__(self, config_path: str, model_kind: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(config_path=config_path,
                                            model_kind=model_kind)
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        # Positions holding the padding id are excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(
            log_dir=self.config_manager.log_dir)

        self.model = self.config_manager.get_model()

        # The optimizer is built before the move. Module.to() works in place,
        # so the optimizer keeps referencing the same parameter objects.
        self.optimizer = self.get_optimizer()
        self.model = self.model.to(self.device)

        # Sets self.global_step (1 for a fresh run).
        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()

        self.initialize_model()

        self.print_config()

    def set_device(self):
        """Use config["device"] when set, else CUDA if available, else CPU."""
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def print_config(self):
        """Dump and print the config, the resume step and the parameter count."""
        self.config_manager.dump_config()
        self.config_manager.print_config()

        if self.global_step > 1:
            print(f"loaded from {self.global_step}")

        parameters_count = count_parameters(self.model)
        print(f"The model has {parameters_count} trainable parameters")

    def load_diacritizer(self):
        """Create the diacritizer used for DER/WER evaluation.

        Only "cbhg"/"baseline" and "seq2seq"/"tacotron_based" get one here.
        For any other kind (e.g. "transformer"), ``self.diacritizer`` is not
        set by this method, so ``evaluate_with_error_rates`` presumably fails
        unless the base class provides one.
        """
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path,
                                               self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path,
                                                  self.model_kind)

    def initialize_model(self):
        """Xavier-initialise a fresh transformer; resumed runs are left untouched."""
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def print_losses(self, step_results, tqdm):
        """Log the step loss plus running averages over config["n_steps_avg_losses"]."""
        self.summary_manager.add_scalar("loss/loss",
                                        step_results["loss"],
                                        global_step=self.global_step)

        tqdm.display(f"loss: {step_results['loss']}", pos=3)
        for pos, n_steps in enumerate(self.config["n_steps_avg_losses"]):
            if len(self.losses) > n_steps:
                avg_loss = sum(self.losses[-n_steps:]) / n_steps
                self.summary_manager.add_scalar(
                    f"loss/loss-{n_steps}",
                    avg_loss,
                    global_step=self.global_step,
                )
                tqdm.display(
                    f"{n_steps}-steps average loss: {avg_loss}",
                    pos=pos + 4,
                )

    def evaluate(self, iterator, tqdm, use_target=True):
        """Compute the mean loss and accuracy over ``iterator`` without gradients.

        Args:
            iterator: batches as dicts with "src", "lengths" and "target".
            tqdm: progress bar advanced once per batch and reset at the end.
            use_target: when False, the model is called with ``target=None``.
                The target is still used to score the predictions.

        Returns:
            ``(mean_loss, mean_accuracy)`` over the batches seen, or
            ``(nan, nan)`` when ``iterator`` yields nothing. The model is left
            in eval mode.
        """
        epoch_loss = 0
        epoch_acc = 0
        n_batches = 0
        self.model.eval()
        tqdm.set_description(f"Eval: {self.global_step}")
        with torch.no_grad():
            for batch_inputs in iterator:
                batch_inputs["src"] = batch_inputs["src"].to(self.device)
                batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
                # Keep the target for scoring even when the model must not see it.
                targets = batch_inputs["target"].to(self.device)
                batch_inputs["target"] = targets if use_target else None

                outputs = self.model(
                    src=batch_inputs["src"],
                    target=batch_inputs["target"],
                    lengths=batch_inputs["lengths"],
                )

                predictions = outputs["diacritics"].contiguous()
                predictions = predictions.view(-1, predictions.shape[-1])
                flat_targets = targets.contiguous().view(-1)
                loss = self.criterion(predictions, flat_targets)
                acc = categorical_accuracy(predictions, flat_targets,
                                           self.pad_idx)
                epoch_loss += loss.item()
                epoch_acc += acc.item()
                n_batches += 1
                tqdm.update()

        tqdm.reset()
        if n_batches == 0:
            return float("nan"), float("nan")
        return epoch_loss / n_batches, epoch_acc / n_batches

    def evaluate_with_error_rates(self, iterator, tqdm):
        """Diacritize up to config["error_rates_n_batches"] batches and score them.

        Originals and predictions are written to ``original.txt`` and
        ``predicted.txt`` in the prediction dir, and DER/WER are computed from
        those files.

        Returns:
            ``(results, summary_texts)``. ``results`` has the keys "DER",
            "DER*", "WER" and "WER*"; the starred variants are computed with
            ``case_ending=False``. ``summary_texts`` holds
            ``(tag, "orig |->  predicted")`` pairs for TensorBoard.
        """
        all_orig = []
        all_predicted = []
        results = {}
        self.diacritizer.set_model(self.model)
        max_batches = int(self.config["error_rates_n_batches"])
        tqdm.set_description(f"Calculating DER/WER {self.global_step}: ")
        for evaluated_batches, batch in enumerate(iterator):
            # Cap the decoding pass at error_rates_n_batches batches.
            if evaluated_batches >= max_batches:
                break

            predicted = self.diacritizer.diacritize_batch(batch)
            all_predicted += predicted
            all_orig += batch["original"]
            tqdm.update()

        orig_path = os.path.join(self.config_manager.prediction_dir,
                                 "original.txt")
        predicted_path = os.path.join(self.config_manager.prediction_dir,
                                      "predicted.txt")

        with open(orig_path, "w", encoding="utf8") as file:
            file.writelines(f"{sentence}\n" for sentence in all_orig)

        with open(predicted_path, "w", encoding="utf8") as file:
            file.writelines(f"{sentence}\n" for sentence in all_predicted)

        # Never index past the end of either list.
        n_texts = min(int(self.config["n_predicted_text_tensorboard"]),
                      len(all_predicted), len(all_orig))
        summary_texts = [
            (f"eval-text/{i}", f"{ all_orig[i]} |->  {all_predicted[i]}")
            for i in range(n_texts)
        ]

        results["DER"] = der.calculate_der_from_path(orig_path, predicted_path)
        results["DER*"] = der.calculate_der_from_path(orig_path,
                                                      predicted_path,
                                                      case_ending=False)
        results["WER"] = wer.calculate_wer_from_path(orig_path, predicted_path)
        results["WER*"] = wer.calculate_wer_from_path(orig_path,
                                                      predicted_path,
                                                      case_ending=False)
        tqdm.reset()
        return results, summary_texts

    def run(self):
        """Train until config["max_steps"], with periodic snapshot and evaluation hooks."""
        scaler = torch.cuda.amp.GradScaler()
        train_iterator, _, validation_iterator = load_iterators(
            self.config_manager)
        print("data loaded")
        print("----------------------------------------------------------")
        tqdm_eval = trange(0, len(validation_iterator), leave=True)
        tqdm_error_rates = trange(0, len(validation_iterator), leave=True)
        tqdm_eval.set_description("Eval")
        tqdm_error_rates.set_description("WER/DER : ")
        tqdm = trange(self.global_step,
                      self.config["max_steps"] + 1,
                      leave=True)

        for batch_inputs in repeater(train_iterator):
            tqdm.set_description(f"Global Step {self.global_step}")
            if self.config["use_decay"]:
                self.lr = self.adjust_learning_rate(
                    self.optimizer, global_step=self.global_step)
            else:
                # Log the optimizer's actual rate instead of the placeholder 0.
                self.lr = self.optimizer.param_groups[0]["lr"]
            self.optimizer.zero_grad()
            if self.device == "cuda" and self.config["use_mixed_precision"]:
                # Only the forward pass (and loss) runs under autocast.
                # Backward and the optimizer step are kept outside the region,
                # as the PyTorch AMP docs recommend.
                with autocast():
                    step_results = self.run_one_step(batch_inputs)
                scaler.scale(step_results["loss"]).backward()
                # Unscale before clipping so the norm is measured on true gradients.
                scaler.unscale_(self.optimizer)
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config["CLIP"])
                scaler.step(self.optimizer)
                scaler.update()
            else:
                step_results = self.run_one_step(batch_inputs)

                loss = step_results["loss"]
                loss.backward()
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.config["CLIP"])
                self.optimizer.step()

            self.losses.append(step_results["loss"].item())

            self.print_losses(step_results, tqdm)

            self.summary_manager.add_scalar("meta/learning_rate",
                                            self.lr,
                                            global_step=self.global_step)

            if self.global_step % self.config["model_save_frequency"] == 0:
                torch.save(
                    {
                        "global_step": self.global_step,
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                    },
                    os.path.join(
                        self.config_manager.models_dir,
                        f"{self.global_step}-snapshot.pt",
                    ),
                )

            if self.global_step % self.config["evaluate_frequency"] == 0:
                loss, acc = self.evaluate(validation_iterator, tqdm_eval)
                self.summary_manager.add_scalar("evaluate/loss",
                                                loss,
                                                global_step=self.global_step)
                self.summary_manager.add_scalar("evaluate/acc",
                                                acc,
                                                global_step=self.global_step)
                tqdm.display(
                    f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}",
                    pos=8)
                self.model.train()

            if (self.global_step %
                    self.config["evaluate_with_error_rates_frequency"] == 0):
                error_rates, summery_texts = self.evaluate_with_error_rates(
                    validation_iterator, tqdm_error_rates)
                if error_rates:
                    WER = error_rates["WER"]
                    DER = error_rates["DER"]
                    DER1 = error_rates["DER*"]
                    WER1 = error_rates["WER*"]

                    # Rates are reported as percentages; log them as fractions.
                    for key in ("WER", "DER", "DER*", "WER*"):
                        self.summary_manager.add_scalar(
                            f"error_rates/{key}",
                            error_rates[key] / 100,
                            global_step=self.global_step,
                        )

                    error_rates = f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"
                    tqdm.display(f"WER/DER {self.global_step}: {error_rates}",
                                 pos=9)

                    for tag, text in summery_texts:
                        self.summary_manager.add_text(tag, text)

                self.model.train()

            if self.global_step % self.config["train_plotting_frequency"] == 0:
                self.plot_attention(step_results)

            self.report(step_results, tqdm)

            self.global_step += 1
            if self.global_step > self.config["max_steps"]:
                print("Training Done.")
                return

            tqdm.update()

    def run_one_step(self, batch_inputs: Dict[str, torch.Tensor]):
        """Run a forward pass and attach the cross-entropy loss as outputs["loss"]."""
        batch_inputs["src"] = batch_inputs["src"].to(self.device)
        batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
        batch_inputs["target"] = batch_inputs["target"].to(self.device)

        outputs = self.model(
            src=batch_inputs["src"],
            target=batch_inputs["target"],
            lengths=batch_inputs["lengths"],
        )

        predictions = outputs["diacritics"].contiguous()
        targets = batch_inputs["target"].contiguous()
        # Flatten so every position is one classification sample.
        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1)
        loss = self.criterion(predictions.to(self.device),
                              targets.to(self.device))
        outputs.update({"loss": loss})
        return outputs

    def predict(self, iterator):
        pass

    def load_model(self, model_path: str = None, load_optimizer: bool = True):
        """Write the model summary to disk, then restore a checkpoint if available.

        Args:
            model_path: explicit checkpoint path. When None, the latest
                snapshot reported by the config manager is used, if any.
            load_optimizer: also restore the optimizer state.

        Sets ``self.global_step`` to 1 for a fresh run, or to the saved step + 1.
        """
        with open(
                self.config_manager.base_dir /
                f"{self.model_kind}_network.txt", "w",
                encoding="utf8") as file:
            file.write(str(self.model))

        if model_path is None:
            last_model_path = self.config_manager.get_last_model_path()
            if last_model_path is None:
                self.global_step = 1
                return
        else:
            last_model_path = model_path

        print(f"loading from {last_model_path}")
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        saved_model = torch.load(last_model_path, map_location=self.device)
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1

    def get_optimizer(self):
        """Build Adam or SGD (momentum 0.9) from config; raise ValueError otherwise."""
        if self.config["optimizer"] == OptimizerType.Adam:
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                betas=(self.config["adam_beta1"], self.config["adam_beta2"]),
                weight_decay=self.config["weight_decay"],
            )
        elif self.config["optimizer"] == OptimizerType.SGD:
            optimizer = optim.SGD(self.model.parameters(),
                                  lr=self.config["learning_rate"],
                                  momentum=0.9)
        else:
            raise ValueError("Optimizer option is not valid")

        return optimizer

    def get_learning_rate(self):
        """Return the warmup decay schedule (warmup_steps defaults to 4000)."""
        return LearningRateDecay(
            lr=self.config["learning_rate"],
            warmup_steps=self.config.get("warmup_steps", 4000.0),
        )

    def adjust_learning_rate(self, optimizer, global_step):
        """Set every param group's lr to the schedule value and return it."""
        learning_rate = self.get_learning_rate()(global_step=global_step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    def plot_attention(self, results):
        pass

    def report(self, results, tqdm):
        pass
示例#20
0
class YOLOMAMLTraining():
    """
    This step handles the meta-training of YOLOMAML on the base dataset
    """

    def __init__(
            self,
            dataset_config='yolov3/config/black.data',
            model_config='yolov3/config/yolov3.cfg',
            pretrained_weights=None,
            n_way=5,
            n_shot=5,
            n_query=16,
            optimizer='Adam',
            learning_rate=0.001,
            approx=True,
            task_update_num=3,
            print_freq=100,
            validation_freq=1000,
            n_epoch=100,
            n_episode=100,
            objectness_threshold=0.8,
            nms_threshold=0.4,
            iou_threshold=0.2,
            image_size=416,
            random_seed=None,
            output_dir=configs.save_dir,
    ):
        """
        Args:
            dataset_config (str): path to data config file
            model_config (str): path to model definition file
            pretrained_weights (str): path to a file containing pretrained weights for the model
            n_way (int): number of labels in a detection task
            n_shot (int): number of support data in each class in an episode
            n_query (int): number of query data in each class in an episode
            optimizer (str): name of an optimizer class of torch.optim (Adam, SGD, ...)
            learning_rate (float): learning rate fed to the optimizer
            approx (bool): whether to use an approximation of the meta-backpropagation
            task_update_num (int): number of updates inside each episode
            print_freq (int): print a status line every print_freq epochs (epochs 0, print_freq, ...)
            validation_freq (int): evaluate on the validation set after every validation_freq-th epoch;
                with the defaults (1000 vs n_epoch=100) validation never runs
            n_epoch (int): number of meta-training epochs
            n_episode (int): number of episodes per epoch during meta-training
            objectness_threshold (float): at evaluation time, only keep boxes with objectness above this threshold
            nms_threshold (float): threshold for non maximum suppression, at evaluation time
            iou_threshold (float): threshold for intersection over union
            image_size (int): size of images (square)
            random_seed (int): seed for random instantiations ; if none is provided, a seed is randomly defined
            output_dir (str): path to experiments output directory (TensorBoard logs and final weights)
        """

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.pretrained_weights = pretrained_weights
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.approx = approx
        self.task_update_num = task_update_num
        self.print_freq = print_freq
        self.validation_freq = validation_freq
        self.n_epoch = n_epoch
        self.n_episode = n_episode
        self.objectness_threshold = objectness_threshold
        self.nms_threshold = nms_threshold
        self.iou_threshold = iou_threshold
        self.image_size = image_size
        self.random_seed = random_seed
        self.checkpoint_dir = output_dir

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.writer = SummaryWriter(log_dir=output_dir)

    def apply(self):
        """
        Execute the YOLOMAMLTraining step
        Returns:
            dict: {'epoch': number of epochs run, 'state': the model's state_dict after the final epoch}

        """
        set_and_print_random_seed(self.random_seed, True, self.checkpoint_dir)

        data_config = parse_data_config(self.dataset_config)
        train_path = data_config["train"]
        train_dict_path = data_config.get("train_dict_path", None)
        valid_path = data_config.get("valid", None)
        valid_dict_path = data_config.get("valid_dict_path", None)

        base_loader = self._get_data_loader(train_path, train_dict_path)
        val_loader = self._get_data_loader(valid_path, valid_dict_path)

        model = self._get_model()

        return self._train(base_loader, val_loader, model)

    def dump_output(self, _, output_folder, output_name, **__):
        pass

    def _train(self, base_loader, val_loader, model):
        """
        Trains the model on the base set, then writes final.weights to the output directory.
        Args:
            base_loader (torch.utils.data.DataLoader): data loader for base set
            val_loader (torch.utils.data.DataLoader): data loader for validation set
            model (YOLOMAML): neural network model to train

        Returns:
            dict: {'epoch': self.n_epoch, 'state': the model's state_dict after the final epoch}
            (no best-model selection is performed)

        """
        optimizer = self._get_optimizer(model)

        try:
            for epoch in range(self.n_epoch):
                loss_dict = model.train_loop(base_loader, optimizer)

                self.plot_tensorboard(loss_dict, epoch)

                if epoch % self.print_freq == 0:
                    print(
                        'Epoch {epoch}/{n_epochs} | Loss {loss}'.format(
                            epoch=epoch,
                            n_epochs=self.n_epoch,
                            loss=loss_dict['query_total_loss'],
                        )
                    )

                # Validate at the end of every validation_freq-th epoch.
                if epoch % self.validation_freq == self.validation_freq - 1:
                    precision, recall, average_precision, f1, _ap_class = model.eval_loop(val_loader)

                    self.writer.add_scalar('precision', precision.mean(), epoch)
                    self.writer.add_scalar('recall', recall.mean(), epoch)
                    self.writer.add_scalar('mAP', average_precision.mean(), epoch)
                    self.writer.add_scalar('F1', f1.mean(), epoch)
        finally:
            # Flush TensorBoard events even if training is interrupted.
            self.writer.close()

        model.base_model.save_darknet_weights(os.path.join(self.checkpoint_dir, 'final.weights'))

        return {'epoch': self.n_epoch, 'state': model.state_dict()}

    def _get_optimizer(self, model):
        """
        Get the optimizer from string self.optimizer
        Args:
            model (torch.nn.Module): the model to be trained

        Returns: a torch.optim.Optimizer object parameterized with model parameters

        Raises:
            ValueError: if self.optimizer does not name an Optimizer subclass of torch.optim

        """
        optimizer_class = getattr(torch.optim, self.optimizer, None)
        # An explicit check (unlike assert) survives `python -O` and rejects
        # non-optimizer attributes such as torch.optim.lr_scheduler.
        if not (isinstance(optimizer_class, type)
                and issubclass(optimizer_class, torch.optim.Optimizer)):
            raise ValueError(
                "The optimization method is not a torch.optim object: {!r}".format(self.optimizer)
            )

        return optimizer_class(model.parameters(), lr=self.learning_rate)

    def _get_data_loader(self, path_to_data_file, path_to_images_per_label):
        """

        Args:
            path_to_data_file (str): path to file containing paths to images
            path_to_images_per_label (str): path to pickle file containing the dictionary of images per label

        Returns:
            torch.utils.data.DataLoader: samples data in the shape of a detection task
        """
        data_manager = DetectionSetDataManager(self.n_way, self.n_shot, self.n_query, self.n_episode, self.image_size)

        return data_manager.get_data_loader(path_to_data_file, path_to_images_per_label)

    def _get_model(self):
        """

        Returns:
            YOLOMAML: meta-model wrapping a Darknet base model built from self.model_config
        """

        base_model = Darknet(self.model_config, self.image_size, self.pretrained_weights)

        model = YOLOMAML(
            base_model,
            self.n_way,
            self.n_shot,
            self.n_query,
            self.image_size,
            approx=self.approx,
            task_update_num=self.task_update_num,
            train_lr=self.learning_rate,
            objectness_threshold=self.objectness_threshold,
            nms_threshold=self.nms_threshold,
            iou_threshold=self.iou_threshold,
            device=self.device,
        )

        return model

    def plot_tensorboard(self, loss_dict, epoch):
        """
        Writes into summary the values present in loss_dict
        Args:
            loss_dict (dict): contains the different parts of the average loss on one epoch. Each key describes
            a part of the loss (ex: query_classification_loss) and each value is a 0-dim tensor. This dictionary is
            required to contain the keys 'support_total_loss' and 'query_total_loss' which contains respectively the
            total loss on the support set, and the total meta-loss on the query set
            epoch (int): global step value in the summary

        Returns:

        """
        for key, value in loss_dict.items():
            self.writer.add_scalar(key, value, epoch)

        return
def main():
    """Train and evaluate the (optionally wavelet-compressed) MNIST network.

    Parses command-line options and trains for ``--epochs`` epochs with
    Adadelta and a per-epoch StepLR schedule. A TensorBoard ``SummaryWriter``
    is passed to ``test`` before training and after each epoch. In Wavelet
    mode the learned filters are plotted at the end.
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1.0,
        metavar="LR",
        help="learning rate (default: 1.0)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument("--no-cuda",
                        action="store_true",
                        default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=250,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument(
        "--compression",
        type=str,
        default="Wavelet",
        help="Choose the compression mode, None, Wavelet, Fastfood",
    )
    parser.add_argument(
        "--wave_loss_weight",
        type=float,
        default=1.0,
        help="Weight term of the wavelet loss",
    )
    parser.add_argument(
        "--wave_dropout",
        type=float,
        default=0.5,
        help="Wavelet layer dropout probability.",
    )

    args = parser.parse_args()
    print(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("../data",
                       train=True,
                       download=True,
                       transform=mnist_transform),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("../data", train=False, transform=mnist_transform),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    if args.compression == "Wavelet":
        # Haar-like alternative initialisation:
        # init_wavelet = ProductFilter(
        #     dec_lo=[0, 0, 0.7071067811865476, 0.7071067811865476, 0, 0],
        #     dec_hi=[0, 0, -0.7071067811865476, 0.7071067811865476, 0, 0],
        #     rec_lo=[0, 0, 0.7071067811865476, 0.7071067811865476, 0, 0],
        #     rec_hi=[0, 0, 0.7071067811865476, -0.7071067811865476, 0, 0],
        #     )

        # random init: four length-6 filters drawn uniformly from [-0.25, 0.25)
        init_wavelet = ProductFilter(
            torch.rand(size=[6], requires_grad=True) / 2 - 0.25,
            torch.rand(size=[6], requires_grad=True) / 2 - 0.25,
            torch.rand(size=[6], requires_grad=True) / 2 - 0.25,
            torch.rand(size=[6], requires_grad=True) / 2 - 0.25,
        )

    else:
        init_wavelet = None

    model = Net(
        compression=args.compression,
        wavelet=init_wavelet,
        wave_dropout=args.wave_dropout,
    ).to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    writer = SummaryWriter()

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    test_wvl_lst = []
    test_acc_lst = []
    test_wvl_loss, test_acc = test(args, model, device, test_loader, writer, 0)
    test_wvl_lst.append(test_wvl_loss.item())
    test_acc_lst.append(test_acc)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_wvl_loss, test_acc = test(args, model, device, test_loader,
                                       writer, epoch)
        test_wvl_lst.append(test_wvl_loss.item())
        test_acc_lst.append(test_acc)
        scheduler.step()
    writer.close()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

    print(compute_parameter_total(model))

    # plt.semilogy(test_wvl_lst)
    # plt.semilogy(test_acc_lst)
    # plt.legend(['wavlet loss', 'accuracy'])
    # plt.show()

    # Only plot filters when a wavelet was supplied. Other modes pass
    # wavelet=None, so model.fc1.wavelet has no filters to show.
    if init_wavelet is not None:
        wavelet = model.fc1.wavelet
        for filt in (wavelet.dec_lo, wavelet.dec_hi, wavelet.rec_lo,
                     wavelet.rec_hi):
            plt.plot(filt.detach().cpu().numpy(), "-*")
        plt.legend(["H_0", "H_1", "F_0", "F_1"])
        plt.show()
    print("done")
示例#22
0
    def __init__(
        self,
        dataset_config='yolov3/config/black.data',
        model_config='yolov3/config/yolov3.cfg',
        pretrained_weights=None,
        n_way=5,
        n_shot=5,
        n_query=16,
        optimizer='Adam',
        learning_rate=0.001,
        approx=True,
        task_update_num=3,
        print_freq=100,
        validation_freq=1000,
        n_epoch=100,
        n_episode=100,
        objectness_threshold=0.8,
        nms_threshold=0.4,
        iou_threshold=0.2,
        image_size=416,
        random_seed=None,
        output_dir=configs.save_dir,
    ):
        """
        Record the configuration of an episodic (meta-)training run and open
        the TensorBoard writer for it.

        Args:
            dataset_config (str): path of the data config file
            model_config (str): path of the model definition file
            pretrained_weights (str): optional path of a pretrained weights file
            n_way (int): number of labels per detection task
            n_shot (int): support examples per class in an episode
            n_query (int): query examples per class in an episode
            optimizer (str): name of a class in torch.optim (Adam, SGD, ...)
            learning_rate (float): learning rate given to the optimizer
            approx (bool): use an approximation of the meta-backpropagation
            task_update_num (int): number of updates performed inside each episode
            print_freq (int): print a status line every print_freq episodes within an epoch
            validation_freq (int): within an epoch, how often the model is evaluated on the validation set
            n_epoch (int): number of meta-training epochs
            n_episode (int): episodes per meta-training epoch
            objectness_threshold (float): evaluation-time minimum objectness for a box to be kept
            nms_threshold (float): evaluation-time non-maximum-suppression threshold
            iou_threshold (float): intersection-over-union threshold
            image_size (int): side length of the (square) input images
            random_seed (int): seed for random instantiations; stored as given here
                (presumably a seed is drawn at run time when None — confirm in the seeding helper)
            output_dir (str): experiment output directory; also used as the
                checkpoint directory and the TensorBoard log directory
        """
        # Data and network definition.
        self.dataset_config = dataset_config
        self.model_config = model_config
        self.pretrained_weights = pretrained_weights
        self.image_size = image_size

        # Composition of each episode.
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query

        # Optimisation settings.
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.approx = approx
        self.task_update_num = task_update_num

        # Schedule and reporting cadence.
        self.n_epoch = n_epoch
        self.n_episode = n_episode
        self.print_freq = print_freq
        self.validation_freq = validation_freq

        # Evaluation-time box filtering.
        self.objectness_threshold = objectness_threshold
        self.nms_threshold = nms_threshold
        self.iou_threshold = iou_threshold

        # Reproducibility and outputs.
        self.random_seed = random_seed
        self.checkpoint_dir = output_dir

        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        self.writer = SummaryWriter(log_dir=output_dir)
示例#23
0
class YOLOTraining:
    """
    This step handles the training of the algorithm on the base dataset
    """
    def __init__(
        self,
        dataset_config='yolov3/config/black.data',
        model_config='yolov3/config/yolov3.cfg',
        pretrained_weights=None,
        optimizer='Adam',
        learning_rate=0.001,
        multiscale_training=True,
        batch_size=32,
        n_cpu=8,
        gradient_accumulation=10,
        print_freq=1,
        validation_freq=5,
        n_epoch=100,
        objectness_threshold=0.8,
        nms_threshold=0.4,
        iou_threshold=0.2,
        image_size=416,
        random_seed=None,
        output_dir=configs.save_dir,
    ):
        """
        Args:
            dataset_config (str): path to data config file
            model_config (str): path to model definition file
            pretrained_weights (str): path to a file containing pretrained weights for the model
            optimizer (str): must be a valid class of torch.optim (Adam, SGD, ...)
            learning_rate (float): learning rate fed to the optimizer
            multiscale_training (bool): whether to sample batches with different image sizes
            batch_size (int): size of a training batch
            n_cpu (int): number of workers for the computation of the dataloader
            gradient_accumulation (int): number of batches whose gradients are accumulated before each
                optimizer step
            print_freq (int): print a status update every print_freq epochs
            validation_freq (int): intended validation frequency; not used by the training loop of this step
            n_epoch (int): number of training epochs
            objectness_threshold (float): at evaluation time, only keep boxes with objectness above this threshold
            nms_threshold (float): threshold for non maximum suppression, at evaluation time
            iou_threshold (float): threshold for intersection over union
            image_size (int): size of images (square)
            random_seed (int): seed for random instantiations ; if none is provided, a seed is randomly defined
            output_dir (str): path to experiments output directory (also used for checkpoints and TensorBoard logs)
        """

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.pretrained_weights = pretrained_weights
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.multiscale_training = multiscale_training
        self.batch_size = batch_size
        self.n_cpu = n_cpu
        self.gradient_accumulation = gradient_accumulation
        self.print_freq = print_freq
        self.validation_freq = validation_freq
        self.n_epoch = n_epoch
        self.objectness_threshold = objectness_threshold
        self.nms_threshold = nms_threshold
        self.iou_threshold = iou_threshold
        self.image_size = image_size
        self.random_seed = random_seed
        self.checkpoint_dir = output_dir

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.writer = SummaryWriter(log_dir=output_dir)

    def apply(self):
        """
        Execute the YOLOTraining step: seed the RNGs, build the data loaders and the
        model, then train.

        Returns:
            dict: {'epoch': n_epoch, 'state': state_dict} describing the model at the
            end of the last epoch (no best-validation checkpoint selection is done)
        """
        set_and_print_random_seed(self.random_seed, True, self.checkpoint_dir)

        data_config = parse_data_config(self.dataset_config)
        train_path = data_config["train"]
        valid_path = data_config.get("valid", None)

        train_loader = self._get_data_loader(train_path)
        # The "valid" entry is optional in the data config; only build a loader
        # when a path is actually given (building one from None would fail).
        val_loader = self._get_data_loader(valid_path) if valid_path else None

        model = self._get_model()

        return self._train(train_loader, val_loader, model)

    def dump_output(self, _, output_folder, output_name, **__):
        pass

    def _train(self, train_loader, val_loader, model):
        """
        Trains the model on the training set
        Args:
            train_loader (torch.utils.data.DataLoader): data loader for training set
            val_loader (torch.utils.data.DataLoader or None): data loader for validation set;
                currently not used by this loop
            model (Darknet): neural network model to train

        Returns:
            dict: {'epoch': n_epoch, 'state': state_dict} of the model after the last epoch

        """
        optimizer = self._get_optimizer(model)
        optimizer.zero_grad()

        for epoch in range(self.n_epoch):
            loss_dict = {}
            n_batches = len(train_loader)

            model.train()
            for batch_index, (_, images, targets) in enumerate(train_loader):
                batch_loss_dict, _ = model(images.to(self.device),
                                           targets.to(self.device))
                loss = batch_loss_dict['total_loss']
                loss.backward()
                loss_dict = include_episode_loss_dict(loss_dict,
                                                      batch_loss_dict,
                                                      n_batches)

                # Step once a full window of `gradient_accumulation` batches has been
                # accumulated (counting batches from 1), and flush the partial window
                # on the last batch so gradients never leak into the next epoch.
                is_window_end = (batch_index + 1) % self.gradient_accumulation == 0
                is_last_batch = batch_index + 1 == n_batches
                if is_window_end or is_last_batch:
                    optimizer.step()
                    optimizer.zero_grad()

            self.plot_tensorboard(loss_dict, epoch)

            if epoch % self.print_freq == 0:
                print('Epoch {epoch}/{n_epochs} | Loss {loss}'.format(
                    epoch=epoch,
                    n_epochs=self.n_epoch,
                    # .get: an empty loader leaves loss_dict empty.
                    loss=loss_dict.get('total_loss'),
                ))

        self.writer.close()

        model.save_darknet_weights(
            os.path.join(self.checkpoint_dir, 'final.weights'))

        return {'epoch': self.n_epoch, 'state': model.state_dict()}

    def _get_optimizer(self, model):
        """
        Get the optimizer from string self.optimizer
        Args:
            model (torch.nn.Module): the model to be trained

        Returns: a torch.optim.Optimizer object parameterized with model parameters

        """
        assert hasattr(torch.optim, self.optimizer
                       ), "The optimization method is not a torch.optim object"
        optimizer = getattr(torch.optim, self.optimizer)(model.parameters(),
                                                         lr=self.learning_rate)

        return optimizer

    def _get_data_loader(self, path_to_data_file):
        """
        Build a shuffled, augmented data loader over the images listed in a file.

        Args:
            path_to_data_file (str): path to file containing paths to images

        Returns:
            torch.utils.data.DataLoader: samples data in the shape of batches
        """
        dataset = ListDataset(path_to_data_file,
                              augment=True,
                              multiscale=self.multiscale_training)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_cpu,
            pin_memory=True,
            collate_fn=dataset.collate_fn,
        )

        return dataloader

    def _get_model(self):
        """
        Build the YOLO network from the model config, optionally loading pretrained
        weights, and move it to self.device.

        Returns:
            Darknet: YOLO model
        """

        model = Darknet(self.model_config, self.image_size,
                        self.pretrained_weights).to(self.device)

        return model

    def plot_tensorboard(self, loss_dict, epoch):
        """
        Writes into summary the values present in loss_dict
        Args:
            loss_dict (dict): the different parts of the loss aggregated over one epoch, as built by
                include_episode_loss_dict from the model's per-batch loss dicts (which include 'total_loss').
                Each key becomes a TensorBoard scalar tag.
            epoch (int): global step value in the summary
        """
        for key, value in loss_dict.items():
            self.writer.add_scalar(key, value, epoch)
示例#24
0
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    """Train a DUQ model on CIFAR-10, logging validation and OoD metrics.

    Args:
        architecture: feature extractor name; one of "WRN", "ResNet18",
            "ResNet50", "ResNet110" or "DenseNet121".
        batch_size: mini-batch size for every data loader.
        length_scale: forwarded to ``ResNet_DUQ``.
        centroid_size: centroid dimensionality; ``None`` means "use the
            feature extractor's output size".
        learning_rate: initial SGD learning rate.
        l_gradient_penalty: weight of the two-sided gradient penalty; the
            penalty is only added to the training loss when this is > 0.
        gamma: forwarded to ``ResNet_DUQ``.
        weight_decay: SGD weight decay.
        final_model: if true, train on the whole training set and validate on
            the test set; otherwise hold out 20% of the training set.
        output_dir: run name; TensorBoard logs and ``model.pt`` go under
            ``runs/<output_dir>``.

    Raises:
        ValueError: if ``architecture`` is not a supported name.
    """
    import copy  # only needed to build the validation split below

    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    _, num_classes, dataset, test_dataset = all_datasets["CIFAR10"]()

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        # First 80% of the shuffled indices train, the rest validate.
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])

        # Assigning ``transform`` on a Subset has no effect: Subset forwards
        # indexing to the wrapped dataset, which keeps its training transform.
        # Give the validation split a shallow copy that uses the test-time
        # preprocessing instead (the underlying data is shared, not copied).
        val_base = copy.copy(dataset)
        val_base.transform = test_dataset.transform
        val_dataset = torch.utils.data.Subset(val_base, idx[train_size:])

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()

        # Adapted resnet from:
        # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
        # NOTE(review): these ResNet-style overrides (conv1/maxpool/fc) are
        # applied only to DenseNet121 here, and a DenseNet may not use these
        # attribute names at all. They look like they were meant for the
        # ResNet branches; confirm against the backbone definitions.
        feature_extractor.conv1 = torch.nn.Conv2d(3,
                                                  64,
                                                  kernel_size=3,
                                                  stride=1,
                                                  padding=1,
                                                  bias=False)
        feature_extractor.maxpool = torch.nn.Identity()
        feature_extractor.fc = torch.nn.Identity()
    else:
        # Without this, an unknown name fails later with an UnboundLocalError.
        raise ValueError(f"Unknown architecture: {architecture!r}")

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.2)

    def calc_gradients_input(x, y_pred):
        # d(sum of y_pred)/dx per sample, kept in the graph so the penalty
        # itself can be back-propagated.
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        # Inputs need gradients for the gradient penalty.
        x.requires_grad_(True)

        y_pred = model(x)

        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        # Needed so the "gradient_penalty" metric can differentiate w.r.t. x.
        x.requires_grad_(True)

        y_pred = model(x)

        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy,
                  output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty,
                  output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        # The (expensive) OoD evaluations only run for the last 5 epochs.
        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)
    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")
    writer.close()
示例#25
0
class ModelTrainerRSS:
    """
    Model trainer that has RSS outputs. The inputs maybe complex or magnitude images.
    """
    def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform, input_val_transform,
                 output_train_transform, output_val_transform, losses, scheduler=None):
        """Validate and store every component the RSS trainer needs.

        Args:
            args: run configuration. This constructor reads ``log_path``, ``run_name``,
                ``max_images``, ``batch_size``, ``save_best_only``, ``ckpt_path``,
                ``max_to_keep``, the optional ``prev_model_ckpt``, ``num_workers``,
                ``verbose``, ``num_epochs``, ``use_slice_metrics`` and ``device``.
            model: network to train (must be an ``nn.Module``).
            optimizer: optimizer for ``model`` (must be an ``optim.Optimizer``).
            train_loader, val_loader: ``DataLoader`` objects; their datasets are also
                concatenated into ``self.concat_loader``.
            input_train_transform, input_val_transform: callables turning a raw batch
                into ``(inputs, targets, extra_params)``.
            output_train_transform, output_val_transform: ``nn.Module`` objects that
                turn model outputs into reconstructions.
            losses: mapping of loss modules; wrapped in ``nn.ModuleDict`` and required
                to contain ``'rss_loss'`` (read by the summary log below).
            scheduler: optional learning-rate scheduler. ``ReduceLROnPlateau`` is
                flagged as metric based; other ``_LRScheduler`` subclasses are not.

        Raises:
            AssertionError: if a component has the wrong type.
            TypeError: if ``scheduler`` is neither ``ReduceLROnPlateau`` nor an ``_LRScheduler``.
        """
        # 'spawn' lets worker processes touch CUDA tensors. It is only set when no
        # start method has been chosen yet, so several runs in one process are fine.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # --- Type checks on the supplied components -------------------------------
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert all(isinstance(loader, DataLoader) for loader in (train_loader, val_loader)), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
        assert all(callable(transform) for transform in (input_train_transform, input_val_transform)), \
            'input_transforms must be callable functions.'
        # Output transforms are required to be nn.Module instances.
        assert all(isinstance(transform, nn.Module) for transform in (output_train_transform, output_val_transform)), \
            '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

        # A dict of losses; a composite loss should be a single module returning a tuple.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            is_metric_based = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
            if not is_metric_based and not isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')
            # Metric-based schedulers are stepped with the epoch loss by the training loops.
            self.metric_scheduler = is_metric_based

        # An interval of 0 disables showing validation images on TensorBoard.
        self.display_interval = (
            0 if args.max_images <= 0
            else int(len(val_loader.dataset) // (args.max_images * args.batch_size))
        )

        self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # Resume model and optimizer state when a previous checkpoint is given.
        previous_ckpt = vars(args).get('prev_model_ckpt')
        if previous_ckpt:
            self.manager.load(load_dir=previous_ckpt, load_optimizer=True)

        self.model, self.optimizer = model, optimizer
        self.train_loader, self.val_loader = train_loader, val_loader

        # Train + val data served together; consumed by _train_concat_epoch.
        self.concat_loader = DataLoader(
            dataset=ConcatDataset([train_loader.dataset, val_loader.dataset]),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            collate_fn=temp_collate_fn, pin_memory=False,
        )

        self.input_train_transform, self.input_val_transform = input_train_transform, input_val_transform
        self.output_train_transform, self.output_val_transform = output_train_transform, output_val_transform
        self.losses = losses
        self.scheduler = scheduler
        self.writer = SummaryWriter(str(args.log_path))

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.use_slice_metrics = args.use_slice_metrics

        # Built once so its kernel is cached; intended to yield SSIM itself, not 1 - SSIM.
        self.ssim = SSIM(filter_size=7).to(device=args.device)

        # Summary of the trainer's components. Train and val transforms are assumed
        # to share a class, so only the val ones are reported.
        self.logger.info(f'''
        Summary of Model Trainer Components:
        Model: {get_class_name(model)}.
        Optimizer: {get_class_name(optimizer)}.
        Input Transforms: {get_class_name(input_val_transform)}.
        Output Transform: {get_class_name(output_val_transform)}.
        RSS Image Domain Loss: {get_class_name(losses['rss_loss'])}.
        Learning-Rate Scheduler: {get_class_name(scheduler)}.
        ''')  # This part has parts different for IMG and CMG losses!!

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Combined Training Loop.')

        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False)

            self.manager.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def train_model_concat(self):
        tic_tic = time()
        self.logger.info('Beginning Concatenated Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Concatenated Training
            train_epoch_loss, train_epoch_metrics = self._train_concat_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True)

            # tic = time()  # Validation
            # val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            # toc = int(time() - tic)
            # self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False)

            self.manager.save(metric=train_epoch_loss, verbose=True)

            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}', group['lr'], global_step=epoch)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=train_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_concat_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow and NaN values.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.concat_loader, start=1)  # Only the data-loader is different from _train_epoch.
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            # Known but minor bug: The tqdm total is accurate only when batch size is 1.
            data_loader = tqdm(data_loader, total=len(self.concat_loader.dataset))

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed inside already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons, targets, extra_params)
                    step_metrics.update(slice_metrics)

                [epoch_metrics[key].append(value.detach()) for key, value in step_metrics.items()]

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_epoch(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Every batch is passed through ``self.input_train_transform`` and optimized by
        ``self._train_step``; slice metrics are added when ``self.use_slice_metrics`` is set.

        Args:
            epoch: 1-based epoch index, used for logging.

        Returns:
            ``(epoch_loss, epoch_metrics)`` as reduced by ``self._get_epoch_outputs``.
        """
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow and NaN values.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            # The loop advances once per batch, so the progress total is the number of
            # batches (len of the loader), not the number of slices in the dataset.
            data_loader = tqdm(data_loader, total=len(self.train_loader))

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed inside already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dict of reconstructions; this class reads its 'rss_recons' entry.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons, targets, extra_params)
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        """Run one optimization step on a single batch.

        Returns:
            ``(recons, step_loss, step_metrics)``: the output of
            ``self.output_train_transform`` plus the loss and metrics from ``self._step``.
        """
        self.optimizer.zero_grad()
        recons = self.output_train_transform(self.model(inputs), targets, extra_params)
        loss_value, metrics = self._step(recons, targets, extra_params)
        # Back-propagate, then apply the parameter update.
        loss_value.backward()
        self.optimizer.step()
        return recons, loss_value, metrics

    def _val_epoch(self, epoch):
        """Run one validation epoch over ``self.val_loader`` with gradients disabled.

        Every batch is passed through ``self.input_val_transform`` and ``self._val_step``;
        slice metrics are added when ``self.use_slice_metrics`` is set, and images are
        forwarded to ``self._visualize_images``.

        Args:
            epoch: 1-based epoch index, used for logging and TensorBoard steps.

        Returns:
            ``(epoch_loss, epoch_metrics)`` as reduced by ``self._get_epoch_outputs``.
        """
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            # One iteration per batch: the total is the number of batches, not slices.
            data_loader = tqdm(data_loader, total=len(self.val_loader))

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons, targets, extra_params)
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Visualize images on TensorBoard.
            self._visualize_images(recons, targets, extra_params, epoch, step, training=False)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)

    def _val_step(self, inputs, targets, extra_params):
        """Evaluate a single validation batch (no parameter update).

        Returns:
            ``(recons, step_loss, step_metrics)`` built from ``self.output_val_transform``
            and ``self._step``.
        """
        predictions = self.model(inputs)
        recons = self.output_val_transform(predictions, targets, extra_params)
        loss_value, metrics = self._step(recons, targets, extra_params)
        return recons, loss_value, metrics

    def _step(self, recons, targets, extra_params):
        """Compute the RSS loss for a batch and collect its metrics.

        ``self.losses['rss_loss']`` may return either a loss or a tuple
        ``(loss, component_dict)``; in the latter case the components are reported as metrics.

        Args:
            recons: dict holding at least the ``'rss_recons'`` reconstruction.
            targets: dict holding at least the ``'rss_targets'`` target.
            extra_params: per-batch parameters; when it contains ``'acceleration'``, the loss
                and its components are also recorded under ``acc_<acceleration>_`` keys.

        Returns:
            ``(step_loss, step_metrics)``.
        """
        step_loss = self.losses['rss_loss'](recons['rss_recons'], targets['rss_targets'])

        # If step_loss is a tuple, it is expected to contain all its component losses as a dict in its second element.
        rss_metrics = dict()
        if isinstance(step_loss, tuple):
            step_loss, rss_metrics = step_loss

        step_metrics = dict()
        if 'acceleration' in extra_params:  # Different metrics for different accelerations.
            acc = extra_params['acceleration']
            for key, value in rss_metrics.items():
                step_metrics[f'acc_{acc}_{key}'] = value
            step_metrics[f'acc_{acc}_loss'] = step_loss

        # Component losses are always reported under their own names, so they are not
        # silently dropped when a batch carries no acceleration information.
        step_metrics.update(rss_metrics)
        return step_loss, step_metrics

    def _visualize_images(self, recons, targets, extra_params, epoch, step, training=False):
        """Write RSS reconstruction images for selected steps to TensorBoard.

        Nothing is written unless ``self.display_interval`` is non-zero and ``step`` is a
        multiple of it. Reconstructions and target-minus-reconstruction deltas are written every
        epoch; targets (and ``targets['rss_inputs']`` when present) only on epoch 1.
        """
        # NOTE: this step-multiple scheme may display an unexpected number of images for
        # certain dataset sizes; check the case where there is no remainder.
        if not self.display_interval or step % self.display_interval:
            return

        mode = 'Training' if training else 'Validation'
        acc = extra_params['acceleration']
        kwargs = dict(global_step=epoch, dataformats='HW')

        recon_image = standardize_image(recons['rss_recons'])
        delta_image = standardize_image(targets['rss_targets'] - recons['rss_recons'])
        self.writer.add_image(f'{mode} RSS Recons/{acc}/{step}', recon_image, **kwargs)
        self.writer.add_image(f'{mode} RSS Deltas/{acc}/{step}', delta_image, **kwargs)

        if epoch != 1:  # Targets do not change, so they are only written once.
            return

        target_image = standardize_image(targets['rss_targets'])
        self.writer.add_image(f'{mode} RSS Targets/{acc}/{step}', target_image, **kwargs)

        # Not actually the network input but the RSS of the input images.
        if 'rss_inputs' in targets:
            input_image = standardize_image(targets['rss_inputs'])
            self.writer.add_image(f'{mode} RSS Inputs/{acc}/{step}', input_image, **kwargs)

    def _get_slice_metrics(self, recons, targets, extra_params):
        """Compute SSIM, PSNR and NMSE between detached RSS reconstructions and targets.

        Each value is stored under a global ``rss/<metric>`` key and again under an
        acceleration-specific ``rss_acc_<acceleration>/<metric>`` key.
        """
        rss_recons = recons['rss_recons'].detach()
        rss_targets = targets['rss_targets'].detach()

        scores = {
            'ssim': self.ssim(rss_recons, rss_targets),
            'psnr': psnr(rss_recons, rss_targets),
            'nmse': nmse(rss_recons, rss_targets),
        }

        # Additional metrics for separating between acceleration factors.
        acc = extra_params["acceleration"]
        rss_metrics = {f'rss/{name}': score for name, score in scores.items()}
        rss_metrics.update({f'rss_acc_{acc}/{name}': score for name, score in scores.items()})
        return rss_metrics

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        """Reduce per-step losses and metrics to epoch-level Python floats.

        Non-finite values (NaN/inf) are excluded from the means and reported through
        ``self.logger.warning``. A non-finite loss additionally turns on autograd anomaly
        detection to help locate its source.

        Args:
            epoch: 1-based epoch index, used in log messages.
            epoch_loss: list of per-step scalar loss tensors.
            epoch_metrics: mapping of metric name to a list of per-step scalar tensors;
                its values are replaced in place by floats.
            training: selects the training or validation dataset for the slice count.

        Returns:
            ``(mean loss as float, epoch_metrics with float values)``.
        """
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        def finite_mean(values):
            # Mean over the finite entries, plus the number of non-finite entries.
            stacked = torch.stack(values)
            is_finite = torch.isfinite(stacked)
            num_bad = (is_finite.size(0) - is_finite.sum()).item()
            return torch.mean(stacked[is_finite]).item(), num_bad

        epoch_loss, num_nans = finite_mean(epoch_loss)
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices. '
                                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)

        # Only existing keys are reassigned, so mutating while iterating is safe.
        for key, value in epoch_metrics.items():
            epoch_metrics[key], num_nans = finite_mean(value)
            if num_nans > 0:
                self.logger.warning(f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices.')

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log the loss and every metric of a single step at INFO level."""
        mode = 'Training' if training else 'Validation'
        prefix = f'Epoch {epoch:03d} Step {step:03d}'
        self.logger.info(f'{prefix} {mode} loss: {step_loss.item():.4e}')
        for name, metric in step_metrics.items():
            self.logger.info(f'{prefix}: {mode} {name}: {metric.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch-level loss, metrics and timing, and mirror them to TensorBoard.

        Metric keys containing ``'loss'`` are written under a ``<mode>/`` tag group, all
        others under a flat ``<mode>_`` prefix. After validation, the learning rate of
        every optimizer parameter group is also recorded.
        """
        mode = 'Training' if training else 'Validation'
        minutes, seconds = divmod(elapsed_secs, 60)
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {minutes} min {seconds} sec')
        self.writer.add_scalar(f'{mode} epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            # Very important whether it is mode_~~ or mode/~~.
            separator = '/' if 'loss' in key else '_'
            self.writer.add_scalar(f'{mode}{separator}epoch_{key}', scalar_value=value, global_step=epoch)

        if training:
            return
        for idx, group in enumerate(self.optimizer.param_groups, start=1):  # Record learning rate.
            self.writer.add_scalar(f'learning_rate_{idx}', group['lr'], global_step=epoch)
parser.add_argument('--nesterov',
                    help='Use nesterov SGD',
                    action='store_true',
                    default=False)

# Bottleneck blocks and data augmentation are on by default.
parser.set_defaults(bottleneck=True, augment=True)

best_prec1 = 0  # Best top-1 precision so far; declared `global` in main().
args = parser.parse_args()
print(args)

if args.tensorboard:
    # Run-name suffix: _<data_set>_<pooling_type>_lr_<lr>_m_<momentum>
    writer = SummaryWriter(comment='_'.join(['', args.data_set, args.pooling_type,
                                             'lr', str(args.lr), 'm', str(args.momentum)]))


def main():
    global args, best_prec1

    torch.manual_seed(args.seed)
    # Data loading code
    # to_tensor transform includes division by 255.
    # see https://pytorch.org/docs/stable/torchvision/
    # transforms.html#torchvision.transforms.ToTensor
    if args.data_set == 'cifar10':

        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
示例#27
0
文件: transformer.py 项目: OxyMal/NLP
        outputs = outputs[:, 1:]

        hyp += get_text_from_tensor(outputs, TRG)
        ref += get_text_from_tensor(trg, TRG)

    # expand dim of reference list
    # sys = ['translation_1', 'translation_2']
    # ref = [['truth_1', 'truth_2'], ['another truth_1', 'another truth_2']]
    ref = [ref]
    return sacrebleu.corpus_bleu(hyp, ref, force=True).score


def inference(model, source_sentence):
    """Translate one source sentence with ``model`` and print the decoded text.

    The sentence is tokenized and numericalized with the ``SRC`` field, decoded by
    ``search``, and converted back to text with the ``TRG`` field. Returns ``None``.
    """
    tokens = SRC.preprocess(source_sentence)
    # `.T` swaps the two axes of the numericalized batch (presumably to batch-first
    # for `search`) — TODO confirm against `search`.
    src = SRC.process([tokens]).T
    decoded = search(model, src)
    print(get_text_from_tensor(decoded, TRG))


if __name__ == "__main__":
    MODEL_PATH = "transformer_model.pt"
    if os.path.exists(MODEL_PATH):
        # A saved model exists: load it, translate a sample sentence and report test BLEU.
        # NOTE(review): torch.load unpickles the whole model object — only load trusted
        # files; newer PyTorch versions may also need weights_only=False for this.
        transformer = torch.load(MODEL_PATH, map_location=device)
        inference(transformer,
                  "Eine Frau mit blonden Haaren trinkt aus einem Glas")
        print(evaluate_bleu(transformer, test_iter))
    else:
        # No saved model yet: train one from scratch (train() receives MODEL_PATH).
        writer = SummaryWriter()
        transformer = Transformer(len(SRC.vocab), len(TRG.vocab)).to(device)
        train(transformer, SRC, TRG, MODEL_PATH)
示例#28
0
class TensorboardSummaryHook:
    """
    Logging object allowing Tensorboard summaries to be automatically exported to the tensorboard. Much of its
    functionality is automated. This means that the hook will export as much information as possible to the
    tensorboard.

    Losses, Metrics, Inputs and Outputs are all interpreted and exported according to their dimensionality. Vectors
    result in mean and standard deviation estimates as well as histograms; Pictures result in image summaries and
    histograms; etc.

    There is also the possibility of comparing pairs of inputs and outputs. This needs to be specified during object
    instantiation.

    Once the user instantiates this object, the workflow corresponding to the ID passed as argument will be
    tracked and the results of the workflow will be exported to the tensorboard.

    .. code-block:: python

            from eisen.utils.logging import TensorboardSummaryHook

            workflow = # Eg. An instance of Training workflow

            logger = TensorboardSummaryHook(workflow.id, 'Training', '/artifacts/dir')
    """
    def __init__(self,
                 workflow_id,
                 phase,
                 artifacts_dir,
                 comparison_pairs=None):
        """
        This method instantiates an object of type TensorboardSummaryHook. The signature of this method is similar to
        that of every other hook. There is one additional parameter called `comparison_pairs` which is meant to
        hold a list of lists each containing a pair of input/output names that share the same dimensionality and can be
        compared to each other.

        A typical use of `comparison_pairs` is when users want to plot a pr_curve or a confusion matrix by comparing
        some input with some output. Eg. by comparing the labels with the predictions.

        .. code-block:: python

            from eisen.utils.logging import TensorboardSummaryHook

            workflow = # Eg. An instance of Training workflow

            logger = TensorboardSummaryHook(
                workflow_id=workflow.id,
                phase='Training',
                artifacts_dir='/artifacts/dir',
                comparison_pairs=[['labels', 'predictions']]
            )

        :param workflow_id: string containing the workflow id of the workflow being monitored (workflow_instance.id)
        :type workflow_id: UUID
        :param phase: string containing the name of the phase (training, testing, ...) of the workflow monitored
        :type phase: str
        :param artifacts_dir: path of an existing directory; summaries are written to <artifacts_dir>/summaries/<phase>
        :type artifacts_dir: str
        :param comparison_pairs: list of lists of pairs, which are names of inputs and outputs to be compared directly
        :type comparison_pairs: list of lists of strings
        :raises ValueError: if artifacts_dir does not exist

        <json>
        [
            {"name": "comparison_pairs", "type": "list:list:string", "value": ""}
        ]
        </json>
        """
        self.workflow_id = workflow_id
        self.phase = phase

        self.comparison_pairs = comparison_pairs

        if not os.path.exists(artifacts_dir):
            raise ValueError(
                'The directory specified to save artifacts does not exist!')

        # Receive the end-of-epoch messages emitted by the monitored workflow only.
        dispatcher.connect(self.end_epoch,
                           signal=EISEN_END_EPOCH_EVENT,
                           sender=workflow_id)

        self.artifacts_dir = os.path.join(artifacts_dir, 'summaries', phase)

        if not os.path.exists(self.artifacts_dir):
            os.makedirs(self.artifacts_dir)

        self.writer = SummaryWriter(log_dir=self.artifacts_dir)

    def end_epoch(self, message):
        """Export the contents of an end-of-epoch message to TensorBoard.

        :param message: dict whose keys read here are 'epoch' (global step), 'losses' and 'metrics'
            (lists of dicts of name -> values) and 'inputs' and 'outputs' (dicts of name -> arrays with `ndim`)
        :type message: dict
        """
        epoch = message['epoch']

        for typ in ['losses', 'metrics']:
            for dct in message[typ]:
                for key in dct.keys():
                    self.write_vector(typ + '/{}'.format(key), dct[key], epoch)

        # Dispatch each input/output on its dimensionality; 3-D arrays go to write_embedding (currently a no-op).
        for typ in ['inputs', 'outputs']:
            for key, value in message[typ].items():
                name = typ + '/{}'.format(key)
                if value.ndim == 5:
                    # Volumetric image (N, C, W, H, D)
                    self.write_volumetric_image(name, value, epoch)
                elif value.ndim == 4:
                    self.write_image(name, value, epoch)
                elif value.ndim == 3:
                    self.write_embedding(name, value, epoch)
                elif value.ndim == 2:
                    self.write_class_probabilities(name, value, epoch)
                elif value.ndim == 1:
                    self.write_vector(name, value, epoch)
                elif value.ndim == 0:
                    self.write_scalar(name, value, epoch)

        if self.comparison_pairs:
            for inp, out in self.comparison_pairs:
                labels = message['inputs'][inp]
                predictions = message['outputs'][out]
                assert labels.ndim == predictions.ndim

                if labels.ndim == 1:
                    base_name = '{}_Vs_{}'.format(inp, out)
                    # in case of binary classification >> PR curve
                    if np.max(labels) <= 1 and np.max(predictions) <= 1:
                        # write_pr_curve appends '/pr_curve' itself; passing the bare base name
                        # avoids a doubled '.../pr_curve/pr_curve' tag.
                        self.write_pr_curve(base_name, labels, predictions, epoch)

                    # in any case for classification >> Confusion Matrix
                    self.write_confusion_matrix(base_name + '/confusion_matrix',
                                                labels, predictions, epoch)

    def write_volumetric_image(self, name, value, global_step):
        """Write a 5-D volume as a video plus mean/std scalars and a histogram.

        Axes 1 and 2 are swapped (presumably mapping (N, C, W, H, D) to the (N, T, C, H, W) layout
        that add_video expects — TODO confirm); unless the resulting channel axis has 1 or 3 entries,
        it is summed down to a single channel.
        """
        value = np.transpose(value, [0, 2, 1, 3, 4])

        if value.shape[2] != 3 and value.shape[2] != 1:
            value = np.sum(value, axis=2, keepdims=True)

        torch_value = torch.tensor(value).float()

        self.writer.add_video(name,
                              torch_value,
                              fps=10,
                              global_step=global_step)
        self.writer.add_scalar(name + '/mean',
                               np.mean(value),
                               global_step=global_step)
        self.writer.add_scalar(name + '/std',
                               np.std(value),
                               global_step=global_step)
        self.writer.add_histogram(name + '/histogram',
                                  value.flatten(),
                                  global_step=global_step)

    def write_image(self, name, value, global_step):
        """Write an NCHW image batch plus mean/std scalars and a histogram."""
        self.writer.add_scalar(name + '/mean',
                               np.mean(value),
                               global_step=global_step)
        self.writer.add_scalar(name + '/std',
                               np.std(value),
                               global_step=global_step)
        self.writer.add_histogram(name + '/histogram',
                                  value.flatten(),
                                  global_step=global_step)
        self.writer.add_images(name,
                               value,
                               global_step=global_step,
                               dataformats='NCHW')

    def write_embedding(self, name, value, global_step):
        """Placeholder for 3-D data; nothing is written yet."""
        pass

    def write_pr_curve(self, name, labels, predictions, global_step):
        """Write a precision/recall curve under the tag ``name + '/pr_curve'``."""
        self.writer.add_pr_curve(name + '/pr_curve', labels, predictions,
                                 global_step)

    def write_confusion_matrix(self, name, labels, predictions, global_step):
        """Render a normalized confusion matrix as an HWC image.

        Class names are taken as ``0 .. max(labels)``; only the first three channels of the
        rendered image are kept (presumably dropping alpha) and scaled to [0, 1].
        """
        cnf_matrix = confusion_matrix(labels, predictions)
        image = plot_confusion_matrix(cnf_matrix,
                                      range(np.max(labels) + 1),
                                      normalize=True,
                                      title=name)[:, :, 0:3]
        self.writer.add_image(name,
                              image.astype(float) / 255.0,
                              global_step=global_step,
                              dataformats='HWC')

    def write_class_probabilities(self, name, value, global_step):
        """Write a 2-D probability matrix as an image plus a histogram of predicted classes.

        Assumes rows are samples and columns are classes (NOTE(review): confirm with callers).
        """
        self.writer.add_image(name,
                              value,
                              global_step=global_step,
                              dataformats='HW')
        # argmax along the class axis gives one predicted class per sample; a bare np.argmax
        # would return a single index into the flattened matrix.
        self.writer.add_histogram(name + '/distribution',
                                  np.argmax(value, axis=1),
                                  global_step=global_step)

    def write_vector(self, name, value, global_step):
        """Write a histogram of ``value`` plus its mean and standard deviation."""
        self.writer.add_histogram(name, value, global_step=global_step)
        self.writer.add_scalar(name + '/mean',
                               np.mean(value),
                               global_step=global_step)
        self.writer.add_scalar(name + '/std',
                               np.std(value),
                               global_step=global_step)

    def write_scalar(self, name, value, global_step):
        """Write a single scalar value."""
        self.writer.add_scalar(name, value, global_step=global_step)
class ModelTrainerK2C:
    """Drive training and validation of a reconstruction model.

    The loss is ``losses['cmg_loss']`` applied to ``recons['cmg_recons']`` and
    ``targets['cmg_targets']``, where ``recons`` is produced by ``output_transform``.
    Epoch losses/metrics are written to a file logger and to TensorBoard, and checkpoints
    are managed by ``CheckpointManager`` in ``'min'`` mode on the validation loss.
    """

    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 input_train_transform, input_val_transform, output_transform, losses, scheduler=None):
        """
        Args:
            args: configuration namespace; reads ``log_path``, ``run_name``, ``max_images``,
                ``batch_size``, ``save_best_only``, ``ckpt_path``, ``max_to_keep``, ``verbose``,
                ``num_epochs``, ``smoothing_factor``, ``use_slice_metrics`` and, optionally,
                ``prev_model_ckpt``.
            model: the ``nn.Module`` to train.
            optimizer: a ``torch.optim.Optimizer``.
            train_loader, val_loader: ``DataLoader`` objects for training and validation.
            input_train_transform, input_val_transform: callables mapping a raw batch to
                ``(inputs, targets, extra_params)``.
            output_transform: ``nn.Module`` mapping ``(outputs, targets, extra_params)`` to a
                dict of reconstructions.
            losses: mapping of loss name to loss module; must contain ``'cmg_loss'``.
            scheduler: optional LR scheduler. ``ReduceLROnPlateau`` is stepped with the validation
                loss, any other ``_LRScheduler`` once per epoch.

        Raises:
            TypeError: if ``scheduler`` is not a PyTorch learning-rate scheduler.
        """
        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with multiple outputs.
        losses = nn.ModuleDict(losses)

        # Defined even without a scheduler so the attribute always exists.
        self.metric_scheduler = False
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif not isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            # Clamp to 1: with fewer validation batches than `max_images`, the floor division
            # would yield 0 and silently disable image display altogether.
            self.display_interval = max(1, int(len(val_loader.dataset) // (args.max_images * args.batch_size)))

        self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.writer = SummaryWriter(str(args.log_path))

    def train_model(self):
        """Run ``self.num_epochs`` epochs of training + validation, checkpointing after each."""
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False)

            self.manager.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_epoch(self, epoch):
        """Run one training epoch; returns ``(epoch_loss, epoch_metrics)`` from ``_get_epoch_outputs``."""
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            # One iteration per batch: the total is the number of batches, not slices.
            data_loader = tqdm(data_loader, total=len(self.train_loader))

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dict of reconstructions; 'cmg_recons', 'img_recons' and 'kspace_recons' are read here.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to scalar and dict with scalar forms.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        """One optimization step; returns ``(recons, step_loss, {})``."""
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)
        step_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        step_loss.backward()
        self.optimizer.step()
        step_metrics = dict()
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        """Run one validation epoch, writing selected images to TensorBoard.

        Returns ``(epoch_loss, epoch_metrics)`` from ``_get_epoch_outputs``.
        """
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            # One iteration per batch: the total is the number of batches, not slices.
            data_loader = tqdm(data_loader, total=len(self.val_loader))

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # This numbering scheme seems to have issues for certain numbers.
            # Please check cases when there is no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                # Change image display function later.
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')

                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')

                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

                if epoch == 1:  # Targets do not change, so they are only written once.
                    self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                    self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')

                    # TODO: Add input images to visualization too.

                # Not read anywhere in this class; kept for any external code that inspects it.
                self.targets_recorded = True

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        """Evaluate one validation batch; returns ``(recons, step_loss, {})``."""
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)
        step_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        step_metrics = dict()
        return recons, step_loss, step_metrics

    @staticmethod
    def _get_slice_metrics(img_recons, img_targets):
        """Compute SSIM, PSNR and NMSE, using the target's value range as the data range."""
        img_recons = img_recons.detach()  # Just in case.
        img_targets = img_targets.detach()

        max_range = img_targets.max() - img_targets.min()
        slice_ssim = ssim_loss(img_recons, img_targets, max_val=max_range)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        return {'slice_ssim': slice_ssim, 'slice_nmse': slice_nmse, 'slice_psnr': slice_psnr}

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        """Reduce per-step losses/metrics to floats, excluding and reporting non-finite values.

        A non-finite loss additionally turns on autograd anomaly detection. ``epoch_metrics``
        values are replaced in place by floats. Returns ``(mean loss, epoch_metrics)``.
        """
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        def finite_mean(values):
            # Mean over the finite entries, plus the number of non-finite entries.
            stacked = torch.stack(values)
            is_finite = torch.isfinite(stacked)
            num_bad = (is_finite.size(0) - is_finite.sum()).item()
            return torch.mean(stacked[is_finite]).item(), num_bad

        epoch_loss, num_nans = finite_mean(epoch_loss)
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices. '
                                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)

        # Only existing keys are reassigned, so mutating while iterating is safe.
        for key, value in epoch_metrics.items():
            epoch_metrics[key], num_nans = finite_mean(value)
            if num_nans > 0:
                self.logger.warning(f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices.')

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log a single step's loss and metrics at INFO level."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
        for key, value in step_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch loss/metrics/timing and mirror them (plus LRs after validation) to TensorBoard."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)

        if not training:  # Record learning rate.
            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}', group['lr'], global_step=epoch)
    def __init__(self, log_dir):
        super().__init__()
        self.log_dir = ensure_dir(log_dir)

        self.writer = SummaryWriter(self.log_dir)