Example #1
0
def joint_fit(opt,
              model,
              task_index,
              continued_task_index,  # currently trained task index
              train_datasets,
              val_datasets=None,
              visualizer=None,
              ):
    """Jointly train ``model`` on all datasets in ``train_datasets``, epoch by epoch.

    Each epoch walks over every training dataset in order. Before each dataset it
    calls ``model.setup(task_index=index, step=2)``. It then accumulates
    ``model.loss_total`` over all batches. If ``val_datasets`` is given, it
    validates on the matching entry and averages the per-task results. Finally it
    logs, checkpoints, and steps the learning-rate schedule.

    Args:
        opt: Options object. Reads ``epoch_start``, ``n_epochs``, ``n_epochs_decay``,
            ``device``, ``curve_freq``, ``save_epoch_freq`` and ``save_best``.
        model: Model being trained; ``continued_task_index`` is stored on it.
        task_index: Only used in the initial log message.
        continued_task_index: Index of the task currently being trained; used as
            the checkpoint key passed to ``model.save_networks``.
        train_datasets: Sequence of per-task iterables yielding ``(image, target)``.
        val_datasets: Optional sequence of per-task validation datasets, indexed in
            parallel with ``train_datasets``. When ``None``, validation and
            best-model saving are skipped.
        visualizer: Optional. When given, the epoch loss is sent to it every
            ``opt.curve_freq`` epochs and it is forwarded to ``val``.
    """
    model.continued_task_index = continued_task_index

    logging.info(f"Fitting task {task_index}")
    best_matrix_item = None
    last_epoch = opt.n_epochs + opt.n_epochs_decay
    for epoch in range(opt.epoch_start, last_epoch + 1):  # outer loop for different epochs
        epoch_start_time = time.time()  # timer for entire epoch

        total_loss = 0
        n_batch = 0

        val_matrix_items = []
        for index, train_dataset in enumerate(train_datasets):

            model.setup(task_index=index, step=2)

            for image, target in train_dataset:  # inner loop within one epoch
                image = image.to(opt.device, non_blocking=True)
                target = target.to(opt.device, non_blocking=True)

                # NOTE(review): `nb_tasks` is a module-level name not visible here;
                # presumably the total number of tasks — confirm.
                multi_output = [None for _ in range(nb_tasks)]
                data = PseudoData(opt, image, target, MultiOutput(multi_output), index)

                model.set_data(data)  # unpack _data from dataset and apply preprocessing
                model.train(index)

                total_loss += model.loss_total
                n_batch += 1

            # Validation on this task (skipped when no validation sets were supplied).
            if val_datasets is not None:
                task_val_matrix, _ = val(val_datasets[index], model, index, visualizer)
                val_matrix_items.append(task_val_matrix)

        # Average the validation results over tasks; stays None when validation was skipped.
        val_matrix = None
        if val_matrix_items:
            val_matrix = my_sum(val_matrix_items) / len(val_matrix_items)

        # Avoid ZeroDivisionError when every training dataset was empty this epoch.
        if n_batch:
            total_loss /= n_batch
        else:
            logging.warning(f"No training batches in epoch {epoch}")

        if visualizer is not None and epoch % opt.curve_freq == 0:
            # visualizing training losses and save logging information to the disk
            visualizer.add_losses({'loss_total': total_loss}, epoch)

        if (epoch + 1) % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            logging.info('saving the model at the end of epoch %d' % (epoch))
            model.save_networks(continued_task_index, epoch)

        if opt.save_best and val_matrix is not None and (
                best_matrix_item is None or val_matrix > best_matrix_item):
            logging.info(f'saving the best model at the end of epoch {epoch}')
            model.save_networks(continued_task_index, epoch="best")
            best_matrix_item = val_matrix

        # total_loss only has `.detach()` once a (presumably tensor) loss was accumulated;
        # with zero batches it is still the int 0.
        loss_value = total_loss.detach() if hasattr(total_loss, "detach") else total_loss
        logging.info(
            f'End of epoch {epoch} / {last_epoch} \t train_loss={loss_value},val:{val_matrix}, Time Taken: {time.time() - epoch_start_time} sec')
        model.update_learning_rate()  # update learning rates at the end of every epoch.
Example #2
0
def val(val_dataset: 'Single task_dataset',
        model: BaseModel,
        task_index,
        visualizer=None) -> Tuple[MatrixItem, List]:
    """Validate ``model`` on a single task's dataset.

    Each element of ``val_dataset`` is indexed with ``["data"]``, and that value is
    expanded as keyword arguments into ``Bunch``. For every batch the model is run
    with ``model.test(visualizer)``, and ``model.get_matrix_item(task_index)`` is
    collected.

    Args:
        val_dataset: Iterable of batches for one task.
        model: Model under evaluation.
        task_index: Task index passed to ``model.get_matrix_item`` and used in logs.
        visualizer: Optional; forwarded to ``model.test``.

    Returns:
        ``(mean_item, items)``. ``items`` is the list of per-batch matrix items, and
        ``mean_item`` is ``my_sum(items) / len(items)``.

    Raises:
        ValueError: If ``val_dataset`` yields no batches, since the mean is undefined.
    """
    logging.info(f"Validating task {task_index}")
    start_time = time.time()  # timer for validating a task

    matrixItems = []
    for data in val_dataset:  # one pass over the validation set
        # NOTE(review): `opt` is not a parameter; this relies on a module-level `opt`.
        model.set_data(PseudoData(opt, Bunch(**data["data"])))
        model.test(visualizer)
        # Add matrixItem result
        matrixItems.append(model.get_matrix_item(task_index))

    if not matrixItems:
        raise ValueError(f"Validation dataset for task {task_index} yielded no batches")

    res = my_sum(matrixItems) / len(matrixItems)
    logging.info(f"Validation Time Taken: {time.time() - start_time} sec")
    return res, matrixItems
Example #3
0
def main(_):
    """Train a freshly built model on task 0, validating after every epoch.

    Args:
        _: Ignored positional argument (presumably argv from an app runner).
    """
    opt = TrainOptions().parse()  # parse the training options
    model = Model(opt)  # build the model described by ``opt``

    train_datasets = create_task_dataset(opt, phase="train")
    val_datasets = create_task_dataset(opt, phase="val")

    task_index = 0
    train_dataset, val_dataset = train_datasets[task_index], val_datasets[task_index]
    logging.info(f"Fitting task {task_index}")

    # NOTE(review): this runs a fixed 10 epochs, while the end-of-epoch log reports
    # opt.n_epochs + opt.n_epochs_decay as the total — confirm which is intended.
    for epoch in range(10):
        tic = time.time()  # wall-clock start of this epoch

        # ---- training pass ----
        running_loss, batch_count = 0, 0
        for batch in train_dataset:
            logging.debug(
                f'Loading dataset {batch["data_name"]}, target={batch["data"]["target"]}'
            )
            model.set_data(Bunch(**batch["data"]))
            model.train()
            running_loss += model.get_current_losses()['loss_total']
            batch_count += 1
        running_loss /= batch_count

        # ---- validation pass ----
        val_items = []
        for batch in val_dataset:
            model.set_data(Bunch(**batch["data"]))
            model.test()
            val_items.append(model.get_matrix_item())
        val_matrix = my_sum(val_items) / len(val_items)

        # per-epoch summary
        logging.info(
            f'End of epoch {epoch} / {opt.n_epochs + opt.n_epochs_decay} \t train_loss={running_loss.detach()},val:{val_matrix}, Time Taken: {time.time() - tic} sec'
        )
Example #4
0
def test(opt,
         test_datasets,
         model: BaseModel,
         train_index,
         task_index,
         visualizer=None):
    """Test the model on every task in ``test_datasets`` after training task ``train_index``.

    For each test dataset, ``val`` computes the dataset-averaged matrix item. That
    item is stored in the module-level ``test_matrix`` under
    ``(train_index, test_index + 1)`` and printed. The items for tasks with
    ``test_index <= task_index`` (the tasks seen so far) are then averaged and
    printed.

    Args:
        opt: Not used in this function; kept for interface compatibility.
        test_datasets: Sequence of per-task test datasets.
        model: Model under evaluation.
        train_index: Row key for ``test_matrix``.
        task_index: Highest task index included in the printed average.
        visualizer: Optional; forwarded to ``val``.

    Returns:
        ``train_index + 1``.

    Side effects:
        Updates the module-level ``test_matrix`` mapping.
    """

    print(
        f'=============================After trained task {train_index - 1}========================================='
    )
    matrixItems = []

    for test_index, test_dataset in enumerate(test_datasets):
        matrixItem, _ = val(
            test_dataset,
            model,
            test_index,
            visualizer,
        )
        test_matrix[(train_index, test_index + 1)] = matrixItem

        if test_index <= task_index:
            # Use the dataset-level average returned by val(). Re-querying
            # model.get_matrix_item(task_index) here would presumably only reflect
            # the last tested batch, and with the wrong task index.
            matrixItems.append(matrixItem)

        print(f'Test in task {test_index}, matrixItem=({matrixItem})')

    if matrixItems:
        res = my_sum(matrixItems) / len(matrixItems)
        print(f'Average Accuracy is ({res})')

    return train_index + 1
Example #5
0
def val(val_dataset: 'Single task_dataset',
        model: BaseModel,
        task_index,
        visualizer=None) -> Tuple[MatrixItem, List]:
    """Validate ``model`` on a single task's dataset.

    Each batch of ``val_dataset`` is unpacked as ``(image, target)``. Both are moved
    to ``opt.device`` and wrapped in ``PseudoData``. The model is then run with
    ``model.test(visualizer)``, and ``model.get_matrix_item(task_index)`` is
    collected.

    Args:
        val_dataset: Iterable of ``(image, target)`` batches for one task.
        model: Model under evaluation.
        task_index: Task index passed to ``model.get_matrix_item`` and used in logs.
        visualizer: Optional; forwarded to ``model.test``.

    Returns:
        ``(mean_item, items)``. ``items`` is the list of per-batch matrix items, and
        ``mean_item`` is ``my_sum(items) / len(items)``.

    Raises:
        ValueError: If ``val_dataset`` yields no batches, since the mean is undefined.
    """
    logging.info(f"Validating task {task_index}")
    start_time = time.time()  # timer for validating a task

    matrixItems = []
    for image, target in val_dataset:  # one pass over the validation set
        # NOTE(review): `opt` is not a parameter; this relies on a module-level `opt`.
        image = image.to(opt.device, non_blocking=True)
        target = target.to(opt.device, non_blocking=True)

        model.set_data(PseudoData(opt, image, target))
        model.test(visualizer)
        # Add matrixItem result
        matrixItems.append(model.get_matrix_item(task_index))

    if not matrixItems:
        raise ValueError(f"Validation dataset for task {task_index} yielded no batches")

    res = my_sum(matrixItems) / len(matrixItems)
    logging.info(f"Validation Time Taken: {time.time() - start_time} sec")
    return res, matrixItems