Esempio n. 1
0
    def load(self, dataset_dir):
        """Load every file in ``dataset_dir`` into ``self.images``.

        The directory is resolved with ``submit.get_path_from_template`` and
        all entries matching ``<dir>/*`` are read in sorted order. When
        ``config.is_image_npy()`` is true files are read with ``np.load``;
        otherwise they are opened with PIL and converted to ``'L'`` (one
        channel) or ``'RGB'``, depending on ``config.get_nb_channels()``.

        Every loaded array is scaled from ``[0, 255]`` to ``[-0.5, 0.5]``:
          * 2-D arrays get a leading axis, giving ``(1, H, W)``;
          * 3-D ``.npy`` arrays are kept in their stored layout
            (presumably already channels-first -- TODO confirm);
          * 3-D PIL images are transposed HWC -> CHW.

        Files that cannot be read are reported and skipped. If the glob
        matches nothing, an error is printed and the process exits with
        status 1.
        """
        import glob

        abs_dirname = os.path.join(submit.get_path_from_template(dataset_dir), '*')
        fnames = sorted(glob.glob(abs_dirname))
        if len(fnames) == 0:
            print('\nERROR: No files found using the following glob pattern:', abs_dirname, '\n')
            sys.exit(1)

        is_npy = config.is_image_npy()
        images = []
        for fname in fnames:
            try:
                if is_npy:
                    im = np.load(fname)
                else:
                    mode = 'L' if config.get_nb_channels() == 1 else 'RGB'
                    # The context manager guarantees the file handle is closed,
                    # including for multi-frame formats.
                    with PIL.Image.open(fname) as src:
                        im = src.convert(mode)
            except (OSError, ValueError, EOFError) as e:
                # OSError: unreadable or unidentified image (PIL), or I/O failure.
                # ValueError / EOFError: what np.load raises for files that are
                # not valid .npy data. Skipping these keeps the npy path as
                # tolerant of stray files as the PIL path.
                print('Skipping file', fname, 'due to error: ', e)
                continue

            arr = np.array(im, dtype=np.float32)

            # With a single channel arr may be (H, W) rather than
            # (1, H, W) or (H, W, 1); add the channel axis explicitly.
            if arr.ndim == 2:
                reshaped = np.array([arr / 255.0 - 0.5])
            elif is_npy:
                reshaped = arr / 255.0 - 0.5
            else:
                reshaped = arr.transpose([2, 0, 1]) / 255.0 - 0.5

            images.append(reshaped)
        self.images = images
Esempio n. 2
0
    def train(args):
        """Apply command-line options to ``train_config`` and launch training.

        Updates ``train_config`` and ``submit_config`` (defined outside this
        block) from ``args``, prints the resulting config and submits the run
        via ``dnnlib.submission.submit.submit_run``.

        :param args: parsed command-line options (presumably an
            ``argparse.Namespace``; options are detected with
            ``'name' in args``). May be falsy/None, in which case the defaults
            already in ``train_config`` are kept.
        """
        def given(name):
            # ``'x' in None`` raises TypeError; treat falsy args as "no options".
            return bool(args) and name in args

        if args:
            n2n = args.noise2noise if 'noise2noise' in args else True
            train_config.noise2noise = n2n
            if 'long_train' in args and args.long_train:
                train_config.iteration_count = 500000
                train_config.eval_interval = 100
                train_config.ramp_down_perc = 0.5
        else:
            print('running with defaults in train_config')

        noise = 'gaussian'
        if given('noise'):
            if args.noise not in corruption_types:
                # NOTE(review): error() presumably aborts; if it returns,
                # the gaussian default below is used.
                error('Unknown noise type', args.noise)
            else:
                noise = args.noise
        train_config.noise = corruption_types[noise]

        if train_config.noise2noise:
            submit_config.run_desc += "-n2n"
        else:
            submit_config.run_desc += "-n2c"

        if given('train_tfrecords') and args.train_tfrecords is not None:
            train_config.train_tfrecords = submit.get_path_from_template(
                args.train_tfrecords)

        print(train_config)
        dnnlib.submission.submit.submit_run(submit_config, **train_config)
Esempio n. 3
0
    def train(args):
        """Apply command-line options to ``train_config`` and launch training.

        :param args: parsed command-line options (presumably an
            ``argparse.Namespace`` exposing ``long_train`` and ``tfrecords``).
            May be falsy/None, in which case the defaults already in
            ``train_config`` (including ``train_tfrecords``) are kept.
        """
        if args:
            if 'long_train' in args and args.long_train:
                train_config.iteration_count = 500000
                train_config.eval_interval = 5000
                train_config.ramp_down_perc = 0.5
            # Only read the dataset path when args were supplied; with falsy
            # args ``args.tfrecords`` would raise despite the defaults
            # promised below.
            train_config.train_tfrecords = submit.get_path_from_template(args.tfrecords)
        else:
            print('running with defaults in train_config')

        print(train_config.train_tfrecords)
        dnnlib.submission.submit.submit_run(submit_config, **train_config)
Esempio n. 4
0
    def train(in_args):
        """Read the arguments given to ``train`` and launch the training.

        Updates ``train_config`` and ``submit_config`` (defined outside this
        block) from ``in_args``, prints the resulting config and submits the
        run via ``dnnlib.submission.submit.submit_run``.

        :param in_args: parsed command-line options (presumably an
            ``argparse.Namespace``; options are detected with
            ``'name' in in_args``). May be falsy/None, in which case the
            defaults already in ``train_config`` are kept.
        :return: None
        """
        def given(name):
            # ``'x' in None`` raises TypeError; treat falsy args as "no options".
            return bool(in_args) and name in in_args

        # Reading 'noise2noise' and 'long_train' args
        if in_args:
            n2n = in_args.noise2noise if 'noise2noise' in in_args else True
            train_config.noise2noise = n2n
            if 'long_train' in in_args and in_args.long_train:
                train_config.iteration_count = 500000
                train_config.eval_interval = 1000
                train_config.ramp_down_perc = 0.5
        else:
            print('running with defaults in train_config')

        # Reading 'noise' argument
        noise = 'gaussian'
        if given('noise'):
            if in_args.noise in corruption_types:
                noise = in_args.noise
            else:
                # NOTE(review): error() presumably aborts; if it returns,
                # the gaussian default is used.
                error('Unknown noise type', in_args.noise)
        train_config.noise = corruption_types[noise]

        # Type of training: noise-to-noise or noise-to-clean.
        # The run description records which one was selected.
        if train_config.noise2noise:
            submit_config.run_desc += "-n2n"
        else:
            submit_config.run_desc += "-n2c"

        # Reading the 'train_tfrecords' directory argument
        if given('train_tfrecords') and in_args.train_tfrecords is not None:
            train_config.train_tfrecords = submit.get_path_from_template(in_args.train_tfrecords)

        # Reading the validation directory. validation_config is only
        # overridden when the option was passed.
        val_dir = 'default'
        if given('val_dir'):
            if in_args.val_dir not in val_datasets:
                error('Unknown validation directory', in_args.val_dir)
            else:
                val_dir = in_args.val_dir
            train_config.validation_config = val_datasets[val_dir]

        # Finally, printing the config and launching the training
        print(train_config)
        dnnlib.submission.submit.submit_run(submit_config, **train_config)
Esempio n. 5
0
    def load(self, dataset_dir):
        """Load every image in ``dataset_dir`` into ``self.images``.

        The directory is resolved with ``submit.get_path_from_template`` and
        all entries matching ``<dir>/*`` are read in sorted order. Each image
        is converted to RGB, transposed HWC -> CHW and scaled from
        ``[0, 255]`` to ``[-0.5, 0.5]``. Unreadable files are reported and
        skipped. If the glob matches nothing, an error is printed and the
        process exits with status 1, rather than silently producing an empty
        dataset.
        """
        import glob
        import sys

        abs_dirname = os.path.join(submit.get_path_from_template(dataset_dir), '*')
        fnames = sorted(glob.glob(abs_dirname))
        if not fnames:
            # An empty self.images would otherwise only surface later as a
            # confusing downstream error.
            print('\nERROR: No files found using the following glob pattern:',
                  abs_dirname, '\n')
            sys.exit(1)

        images = []
        for fname in fnames:
            try:
                # The context manager guarantees the file handle is closed.
                with PIL.Image.open(fname) as src:
                    im = src.convert('RGB')
            except OSError as e:
                print('Skipping file', fname, 'due to error: ', e)
                continue
            arr = np.array(im, dtype=np.float32)
            images.append(arr.transpose([2, 0, 1]) / 255.0 - 0.5)
        self.images = images
    def load(self, dataset_dir):
        """Populate ``self.images`` from every file under ``dataset_dir``.

        Paths come from globbing ``<dir>/*`` (after template expansion) and are
        processed in sorted order. Each readable image becomes a float32
        channels-first RGB array in ``[-0.5, 0.5]``; unreadable files are
        reported and skipped. An empty match prints an error and exits(1).
        """
        import glob

        pattern = os.path.join(submit.get_path_from_template(dataset_dir), '*')
        paths = sorted(glob.glob(pattern))
        if not paths:
            print('\nERROR: No files found using the following glob pattern:',
                  pattern, '\n')
            sys.exit(1)

        loaded = []
        for path in paths:
            try:
                pixels = np.array(PIL.Image.open(path).convert('RGB'),
                                  dtype=np.float32)
            except OSError as err:
                print('Skipping file', path, 'due to error: ', err)
                continue
            # HWC -> CHW, then map [0, 255] onto [-0.5, 0.5].
            chw = pixels.transpose([2, 0, 1])
            loaded.append(chw / 255.0 - 0.5)
        self.images = loaded
Esempio n. 7
0
    def load(self, dataset_dir):
        """Load single-channel images from ``dataset_dir`` into ``self.images``.

        Each file matching ``<dir>/*`` (sorted, after template expansion) is
        read with ``np.load`` when possible, otherwise opened with PIL as RGB.
        Processing steps:
          * 3-D data is averaged over its last axis to a single channel;
          * the result is cropped to at most 256x256;
          * a leading channel axis is added, giving ``(1, H, W)``;
          * values are min-max rescaled to ``[-0.5, 0.5]``.
        A constant image, which has no range to rescale, becomes all ``-0.5``.

        Unreadable files are reported and skipped. If the glob matches
        nothing, an error is printed and the process exits with status 1.
        """
        import glob

        abs_dirname = os.path.join(submit.get_path_from_template(dataset_dir),
                                   '*')
        fnames = sorted(glob.glob(abs_dirname))
        if len(fnames) == 0:
            print('\nERROR: No files found using the following glob pattern:',
                  abs_dirname, '\n')
            sys.exit(1)

        images = []
        for fname in fnames:
            try:
                try:
                    im = np.load(fname)
                except (OSError, ValueError, EOFError):
                    # Not a loadable .npy (e.g. an image file): fall back to
                    # PIL. Named exceptions instead of a bare ``except:`` so
                    # KeyboardInterrupt/SystemExit still propagate.
                    with PIL.Image.open(fname) as src:
                        im = src.convert('RGB')
            except OSError as e:
                print('Skipping file', fname, 'due to error: ', e)
                continue

            data = np.asarray(im)
            if data.ndim == 3:
                # Collapse the channel axis into one grayscale channel.
                data = np.mean(data, axis=-1)
            arr = np.array(data, dtype=np.float32)
            # Crop to at most 256x256 and add the leading channel axis.
            reshaped = arr[np.newaxis, 0:256, 0:256]

            lo = reshaped.min()
            span = reshaped.max() - lo
            if span == 0:
                # Min-max scaling would divide by zero and produce NaN/inf.
                reshaped = np.full_like(reshaped, -0.5)
            else:
                reshaped = (reshaped - lo) / span - 0.5
            images.append(reshaped)
        self.images = images
Esempio n. 8
0
def load_snapshot(fname):
    """Unpickle and return the object stored at ``fname``.

    ``fname`` is first expanded with ``submit.get_path_from_template``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing; only load snapshots from trusted sources.
    """
    # os.path.join with a single argument was a no-op, so it is dropped.
    path = submit.get_path_from_template(fname)
    with open(path, "rb") as f:
        return pickle.load(f)