예제 #1
0
def test_train_options_model_specified():
    """Print the training options after the model-specific defaults are applied.

    Parsing happens in two passes. The first pass reads ``--model`` from the
    ``TrainOptions`` parser. That model's option setter, obtained from
    ``models.get_option_setter``, then extends the parser, and the command
    line is parsed again so the new defaults take effect. Finally the model
    class from ``find_model_using_name`` and the model's input options are
    printed.
    """
    print("-----------------------------------------------------")
    print("Test: test_train_options_model_specified")
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    train_options = TrainOptions()
    parser = train_options.initialize(parser)
    opts, _ = parser.parse_known_args()

    model_name = opts.model
    model_option_setter = models.get_option_setter(model_name)
    parser = model_option_setter(parser, is_train=train_options.isTrain)
    opts, _ = parser.parse_known_args()  # parse again with new defaults

    print_basic_opts(opts)
    print_train_opts(opts)

    model = find_model_using_name(model_name)
    print(model)

    # This test is about model options, so the section is labelled as such.
    # NOTE(review): the input_* options are presumably registered by the model
    # option setter; confirm.
    print("------------------- Model Config -------------------")
    for opt_name in ("input_means", "input_size", "input_range",
                     "input_nc", "input_std"):
        # The value is a separate print() argument, so a tuple-valued option
        # is printed whole instead of being unpacked by %-formatting.
        print(".. --%s" % opt_name, getattr(opts, opt_name))

    print("-----------------------------------------------------\n\n")
예제 #2
0
def test_train_options_dataset_specified():
    """Print the training options after the dataset-specific defaults are applied.

    Parsing happens in two passes. The first pass reads ``--dataset_mode``
    from the ``TrainOptions`` parser. The matching dataset option setter,
    obtained from ``get_option_setter``, then extends the parser, and the
    command line is parsed again. The dataset class from
    ``find_dataset_using_name`` and a few dataset options are printed
    afterwards.
    """
    print("-----------------------------------------------------")
    print("Test: test_train_options_dataset_specified")
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    train_options = TrainOptions()
    arg_parser = train_options.initialize(arg_parser)
    # First pass: only needed to learn which dataset mode was requested.
    mode = arg_parser.parse_known_args()[0].dataset_mode

    add_dataset_args = get_option_setter(mode)
    arg_parser = add_dataset_args(arg_parser, is_train=train_options.isTrain)
    # Second pass: now includes the defaults registered by the dataset setter.
    parsed = arg_parser.parse_known_args()[0]

    print_basic_opts(parsed)
    print_train_opts(parsed)

    print(find_dataset_using_name(mode))

    print("------------------- Dataset Config -----------------")
    print(".. --dataset_mode %s" % parsed.dataset_mode)
    print(".. --img_name_tmpl %s" % parsed.img_name_tmpl)
    print(".. --split_dir %s" % parsed.split_dir)

    print("-----------------------------------------------------\n\n")
예제 #3
0
    def __init__(self):
        """Run the whole training loop as part of construction.

        Parses ``TrainOptions`` and builds the dataset and model from them.
        It then iterates over the dataset for ``epochs`` epochs, calling the
        model on every batch. Every ``save_epoch_freq`` epochs it prints a
        save message.

        NOTE(review): ``epochs`` is not defined in this method. It must be a
        module-level name elsewhere in the file, otherwise this raises
        NameError; confirm. The "saving" branch only prints; no checkpoint is
        written here.
        """
        options = TrainOptions().parse()
        # The dataset is selected from options.dataset_mode and related options.
        train_set = create_dataset(options)
        print('The number of training images = %d' % len(train_set))

        net = create_model(options)
        # Regular setup: load and print networks; create schedulers.
        net.prepare_model(options)

        seen = 0  # samples processed across all epochs
        for epoch in range(epochs):
            epoch_start_time = time.time()
            data_ready_at = time.time()  # when the current batch became available
            seen_this_epoch = 0  # reset at the start of every epoch

            for batch in train_set:
                compute_start = time.time()
                if seen % options.print_freq == 0:
                    # Data-loading time. It is not reported anywhere because
                    # the visualizer and loss logging are currently disabled.
                    t_data = compute_start - data_ready_at
                seen += options.batch_size
                seen_this_epoch += options.batch_size
                # Loss computation, gradients and the weight update all
                # happen inside the model call.
                net(batch)
                data_ready_at = time.time()

            if epoch % options.save_epoch_freq == 0:
                print('saving the model at the end of epoch %d, iters %d' %
                      (epoch, seen))
예제 #4
0
import argparse

import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torchsummary import summary

import pretrainedmodels

from core.dataset import create_dataset_loader
from core.options.train_options import TrainOptions


opts = TrainOptions().parse()   # get training options
# Build the loader for the 'train' split. The dataset is chosen by
# opts.dataset_mode and related options.
test_loader = create_dataset_loader(opts, phase='train')

# len() of a torch DataLoader is the number of *batches*, while the number of
# samples is the length of the wrapped dataset.
# NOTE(review): this assumes create_dataset_loader returns a
# torch.utils.data.DataLoader (it exposes .dataset, used below); confirm.
test_dataset_size = len(test_loader.dataset)    # number of images in the dataset
test_num_batches = len(test_loader)             # number of batches per epoch

# The loader is built with phase='train', so the images counted are training images.
print('The number of training images = %d' % test_dataset_size)
print('The number of batches = %d' % test_num_batches)

# training options
batch_size = opts.batch_size


# Print a summary of the loader configuration.
print("----------------------------------------------------------------------------------------")
print("TEST DATA LOADER:")
print("Dataset: %s" % type(test_loader.dataset).__name__)
print("Modality: %s" % opts.modality)
예제 #5
0
def test_train_option_initialise():
    """Run ``options_initialise_default`` on a freshly constructed ``TrainOptions``."""
    print("-----------------------------------------------------")
    print("Test: test_train_option_initialise")
    options_initialise_default(TrainOptions())
    print("-----------------------------------------------------\n\n")
예제 #6
0
def test_options_parse():
    """Smoke-test ``TrainOptions().parse()`` against the current command line.

    The test passes if parsing completes without raising. The parsed options
    are not inspected.
    """
    print("-----------------------------------------------------")
    # The banner names this test; it previously printed another test's name.
    print("Test: test_options_parse")

    TrainOptions().parse()
    print("-----------------------------------------------------\n\n")