Example #1
def main(config, resume):
    supercomputer = config.get('super_computer', False)
    #set_procname(config['name'])
    #np.random.seed(1234) I don't have a way of restarting the DataLoader at the same place, so this makes it totally random
    train_logger = Logger()

    split = config['split'] if 'split' in config else 'train'
    data_loader, valid_data_loader = getDataLoader(config, split)
    #valid_data_loader = data_loader.split_validation()

    model = eval(config['arch'])(config['model'])
    if 'style' in config['model'] and 'lookup' in config['model']['style']:
        model.style_extractor.add_authors(data_loader.dataset.authors)  #register dataset authors for the style lookup
    model.summary()
    if config['trainer']['class'] == 'HWRWithSynthTrainer':
        gen_model = model
        model = model.hwr
        gen_model.hwr = None
        #config['gen_model$'] = gen_model
    if type(config['loss']) == dict:
        loss = {}  #[eval(l) for l in config['loss']]
        for name, l in config['loss'].items():
            loss[name] = eval(l)
    else:
        loss = eval(config['loss'])
    if type(config['metrics']) == dict:
        metrics = {}
        for name, m in config['metrics'].items():
            metrics[name] = [eval(metric) for metric in m]
    else:
        metrics = [eval(metric) for metric in config['metrics']]

    if 'class' in config['trainer']:
        trainerClass = eval(config['trainer']['class'])
    else:
        trainerClass = Trainer
    trainer = trainerClass(model,
                           loss,
                           metrics,
                           resume=resume,
                           config=config,
                           data_loader=data_loader,
                           valid_data_loader=valid_data_loader,
                           train_logger=train_logger)
    if config['trainer']['class'] == 'HWRWithSynthTrainer':
        trainer.gen = gen_model

    name = config['name']

    def handleSIGINT(sig, frame):
        trainer.save()
        sys.exit(0)

    signal.signal(signal.SIGINT, handleSIGINT)

    print("Begin training")
    trainer.train()
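
For reference, the config consumed by this entry point is a plain dict, typically parsed from JSON. Below is a minimal sketch of its shape: it lists only keys this main actually reads, and the architecture, loss, and trainer names are illustrative placeholders rather than names taken from the project.

# Sketch (illustrative) of the config dict main() expects.
example_config = {
    "name": "hwr_experiment",         # experiment name
    "arch": "HWRModel",               # eval()'d to obtain the model class (hypothetical name)
    "model": {},                      # dict passed to the model constructor
    "loss": {"recog": "CTCLoss"},     # dict form: name -> expression to eval()
    "metrics": {"recog": ["cer"]},    # dict form: name -> list of expressions to eval()
    "trainer": {"class": "Trainer"},  # optional 'class'; defaults to Trainer
    "split": "train",                 # optional; defaults to 'train'
    "super_computer": False,          # optional flag read at the top
}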
Example #2
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 resume,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 train_logger=None):
        super(HWRWithSynthTrainer, self).__init__(model, loss, metrics, resume,
                                                  config, train_logger)
        self.config = config
        if 'loss_params' in config:
            self.loss_params = config['loss_params']
        else:
            self.loss_params = {}
        for lossname in self.loss:
            if lossname not in self.loss_params:
                self.loss_params[lossname] = {}
        self.lossWeights = config.get('loss_weights', {"auto": 1, "recog": 1})
        self.batch_size = data_loader.batch_size
        self.data_loader = data_loader
        self.data_loader_iter = iter(data_loader)
        self.valid_data_loader = valid_data_loader
        self.valid = self.valid_data_loader is not None

        char_set_path = config['data_loader']['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.idx_to_char = {}
        self.num_class = len(char_set['idx_to_char']) + 1
        for k, v in char_set['idx_to_char'].items():
            self.idx_to_char[int(k)] = v

        if 'synth_data' in config:
            config_synth = {
                'data_loader': config['synth_data'],
                'validation': {}
            }
            self.synth_data_loader, _ = getDataLoader(config_synth, 'train')
            self.authors_of_interest = self.synth_data_loader.dataset.authors_of_interest
        else:
            self.authors_of_interest = None
        self.synth_data_loader_iter = None
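
The char_file JSON loaded above stores the alphabet in both directions; JSON object keys are strings, which is why the loop converts them with int(k). A minimal sketch of its shape follows, with illustrative contents; the +1 on num_class presumably reserves index 0, e.g. for a CTC blank.

# Illustrative shape of the char_file JSON read by the constructor above.
char_set = {
    "char_to_idx": {"a": 1, "b": 2},
    "idx_to_char": {"1": "a", "2": "b"},  # string keys, as JSON requires
}
idx_to_char = {int(k): v for k, v in char_set["idx_to_char"].items()}
num_class = len(char_set["idx_to_char"]) + 1  # presumably a blank/padding class at 0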
Example #3
def main(config, resume):
    set_procname(config['name'])
    #np.random.seed(1234) I don't have a way of restarting the DataLoader at the same place, so this makes it totally random
    train_logger = Logger()

    split = config['split'] if 'split' in config else 'train'
    data_loader, valid_data_loader = getDataLoader(config, split)
    #valid_data_loader = data_loader.split_validation()

    model = eval(config['arch'])(config['model'])
    model.summary()
    if type(config['loss']) == dict:
        loss = {}  #[eval(l) for l in config['loss']]
        for name, l in config['loss'].items():
            loss[name] = eval(l)
    else:
        loss = eval(config['loss'])
    if type(config['metrics']) == dict:
        metrics = {}
        for name, m in config['metrics'].items():
            metrics[name] = [eval(metric) for metric in m]
    else:
        metrics = [eval(metric) for metric in config['metrics']]

    if 'class' in config['trainer']:
        trainerClass = eval(config['trainer']['class'])
    else:
        trainerClass = Trainer
    trainer = trainerClass(model,
                           loss,
                           metrics,
                           resume=resume,
                           config=config,
                           data_loader=data_loader,
                           valid_data_loader=valid_data_loader,
                           train_logger=train_logger)

    def handleSIGINT(sig, frame):
        trainer.save()
        sys.exit(0)

    signal.signal(signal.SIGINT, handleSIGINT)

    print("Begin training")
    trainer.train()
Example #4
    def test_getDataLoaderBatch2(self):
        #TODO: Exact value matches for vertices, normals, edges, proj_gt

        self.params.batch_size = 2
        max_vertices, feature_size, data_size, max_total_vertices = data_loader.getMetaData(
            self.params, self.data_dir)
        generator = data_loader.getDataLoader(self.params, self.data_dir,
                                              max_total_vertices, feature_size)
        train_data, train_data_normal, edges, image_feat, proj_gt = next(
            generator)
        self.assertEqual(
            train_data.shape,
            (self.params.batch_size, max_total_vertices, self.params.dim_size))
        self.assertEqual(
            train_data_normal.shape,
            (self.params.batch_size, max_total_vertices, self.params.dim_size))
        self.assertEqual(edges.shape, (self.params.batch_size, ))
        self.assertEqual(image_feat.shape,
                         (self.params.batch_size, feature_size))
        self.assertEqual(proj_gt.shape,
                         (self.params.batch_size, self.params.img_width))
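
The assertions above pin down the batch contract of the generator; the sketch below restates it with concrete, purely illustrative sizes. The object dtype for edges is an assumption based on its 1-D shape, which suggests ragged per-sample edge data.

import numpy as np

# Batch contract implied by the assertions above; all sizes are illustrative.
batch_size, max_total_vertices, dim_size = 2, 1024, 3
feature_size, img_width = 128, 64

train_data = np.zeros((batch_size, max_total_vertices, dim_size))         # vertices
train_data_normal = np.zeros((batch_size, max_total_vertices, dim_size))  # normals
edges = np.empty((batch_size,), dtype=object)  # one ragged edge array per sample
image_feat = np.zeros((batch_size, feature_size))
proj_gt = np.zeros((batch_size, img_width))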
Example #5
    logger = log_utils.Logger(cfg, log_path=os.path.join(cfg.log_path, log_name), model_name="pix2pix")
    if mode == "train":
        if not os.path.exists(cfg.model_root):
            os.makedirs(cfg.model_root)
        if not os.path.exists(cfg.log_path):
            os.makedirs(cfg.log_path)

    log_utils.setLogger(cfg.log_file)
    print("Logging file", cfg.log_file)
    printConfig(cfg)

    #load data
    dloader = {}
    if mode == "train":
        dloader["train"] = dl.getDataLoader(cfg, "train")
        val_cfg = config.ConfigTest("test")
        dloader["test"] = dl.getDataLoader(val_cfg, "test")
    else:  #test mode
        dloader["test"] = dl.getDataLoader(cfg, "test")
    #testDataLoad(dloader)
    #print(cfg.load_saved_D, cfg.load_saved_G)
    #exit(0)

    # Train
    pix_gan = pix2pix.Pix2Pix(cfg, logger)
    if mode == "train":
        pix_gan.trainModel(dloader)
    else:
        pix_gan.testModel(dloader["test"], cfg.ep_cnt)
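
Incidentally, each exists()/makedirs() pair above has a race between the check and the creation; on Python 3 the same effect is a one-liner. A drop-in alternative, with literal paths standing in for cfg.model_root and cfg.log_path:

import os

# Equivalent to the exists()+makedirs() pairs above, minus the race condition.
os.makedirs("models", exist_ok=True)  # stands in for cfg.model_root
os.makedirs("logs", exist_ok=True)    # stands in for cfg.log_path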
      
Example #6
def main(resume,
         saveDir,
         gpu=None,
         config=None,
         addToConfig=None,
         fromDataset=True,
         test=False,
         arguments=None,
         style_loc=None):
    np.random.seed(1234)
    torch.manual_seed(1234)
    if resume is not None:
        checkpoint = torch.load(resume,
                                map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        ##HACK fix
        keys = list(checkpoint['state_dict'].keys())
        for key in keys:
            if 'style_from_normal' in key:  #HACK
                del checkpoint['state_dict'][key]
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if 'pretrained' in key:
                config[key] = None
    else:
        checkpoint = None
        config = json.load(open(config))
    config['optimizer_type'] = "none"
    config['trainer']['use_learning_schedule'] = False
    config['trainer']['swa'] = False
    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu
    addDATASET = False
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            printM += add[-2] + ']={}'.format(value)
            print(printM)
            if (add[-2] == 'useDetections'
                    or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    if fromDataset:
        config['data_loader']['batch_size'] = 1
        config['validation']['batch_size'] = 1
        if not test:
            data_loader, valid_data_loader = getDataLoader(config, 'train')
        else:
            config['data_loader']['a_batch_size'] = 1
            config['validation']['a_batch_size'] = 1
            print('changed a_batch_size to 1')
            test_data_loader, _ = getDataLoader(config, 'test')
            valid_data_loader = test_data_loader

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['arch'])(config['model'])
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    model.summary()
    if gpu is not None:
        model = model.to(gpu)
    model.count_std = 0
    model.dup_std = 0

    gt_mask = 'create_mask' not in config['model']  #'mask' in config['model']['generator'] or 'Mask' in config['model']['generator']

    char_set_path = config['data_loader']['char_file']
    with open(char_set_path) as f:
        char_set = json.load(f)
    char_to_idx = char_set['char_to_idx']

    by_author_styles = defaultdict(list)
    by_author_all_ids = defaultdict(set)
    #style_loc = config['style_loc'] if 'style_loc' in config else style_loc
    if style_loc is not None:
        if style_loc[-1] != '*':
            style_loc += '*'
        all_style_files = glob(style_loc)
        assert (len(all_style_files) > 0)
        for loc in all_style_files:
            #print('loading '+loc)
            with open(loc, 'rb') as f:
                styles = pickle.load(f)
            if 'ids' in styles:
                for i in range(len(styles['authors'])):
                    by_author_styles[styles['authors'][i]].append(
                        (styles['styles'][i], styles['ids'][i]))
                    by_author_all_ids[styles['authors'][i]].update(
                        styles['ids'][i])
            else:
                for i in range(len(styles['authors'])):
                    by_author_styles[styles['authors'][i]].append(
                        (styles['styles'][i], None))

        styles = defaultdict(list)
        authors = set()
        for author in by_author_styles:
            for style, ids in by_author_styles[author]:
                styles[author].append(style)
            if len(styles[author]) > 0:
                authors.add(author)
        authors = list(authors)
    elif not test:
        authors = valid_data_loader.dataset.authors
        styles = None
    else:
        styles = None

    num_char = config['model']['num_class']
    use_hwr_pred_for_style = config['trainer'].get('use_hwr_pred_for_style', False)

    charSpec = model.char_style_dim > 0

    with torch.no_grad():
        while True:
            if arguments is None:
                action = input(
                    'indexes/random interp/vae random/stretch/author display/math/turk gen/from-to/umap-images/Random styles/help/quit? '
                )  #indexes/random/vae/stretch/author-list/quit
            else:
                action = arguments['choice']
                arguments['choice'] = 'q'
            if action == 'done' or action == 'exit' or action == 'quit' or action == 'q':
                exit()
            elif action[0] == 'h':  #help
                print('Options:')
                print('[a] show author ids')
                print(
                    '[r] random interpolation: selects n styles (dataset extracted) and interpolates between them in a circular pattern'
                )
                print(
                    '[v] same as above, but styles are randomly sampled from a Gaussian distribution (for VAE)'
                )
                print(
                    '[s] stretch: manipulate the 1d text encoding to interpolate horizontal stretching'
                )
                print(
                    '[m] vector math: perform vector math with style vectors. Use "+" and "-". Use an author id to specify a vector, and [id1,id2,...] to average vectors together.'
                )
                print(
                    '[t] MTurk gen: routine used to generate data for the MTurk experiment'
                )
                print(
                    '[R] Random: generate n images using random (interpolated) styles. Can use fixed or random text'
                )
                print(
                    "[f] Given two image paths, interpolate from one style to the other using the given text."
                )
            elif action == 'a' or action == 'authors':
                print(authors)
            elif action == 's' or action == 'stretch':
                index1 = input("batch? ")
                if len(index1) > 0:
                    index1 = int(index1)
                else:
                    index1 = 0
                for i, instance1 in enumerate(valid_data_loader):
                    if i == index1:
                        break
                author1 = instance1['author'][0]
                style1 = get_style(config, model, instance1, gpu)
                image = instance1['image']
                label = instance1['label']
                if gpu is not None:
                    image = image.to(gpu)
                    label = label.to(gpu)
                pred = model.hwr(image, None)
                if use_hwr_pred_for_style:
                    spaced_label = pred
                else:
                    spaced_label = model.correct_pred(pred, label)
                    spaced_label = model.onehot(spaced_label)
                images = interpolate_horz(model, style1, spaced_label)
                for b in range(images[0].size(0)):
                    for i in range(len(images)):
                        genStep = ((1 - images[i][b].permute(1, 2, 0)) *
                                   127.5).cpu().numpy().astype(np.uint8)
                        path = os.path.join(saveDir,
                                            'gen{}_{}.png'.format(b, i))
                        cv2.imwrite(path, genStep)
            elif action[0] == 'r' or action[
                    0] == 'v':  #interpolate randomly selected styles, "v" is VAE
                num_styles = int(input('number of styles? '))
                step = float(input('step (0.1 is normal)? '))
                text = input('text? ')
                if len(text) == 0:
                    text = 'The quick brown fox jumps over the lazy dog.'
                stylesL = []
                if action[0] == 'r':
                    index = random.randint(0, 20)
                    last_author = None
                    for i, instance in enumerate(valid_data_loader):
                        author = instance['author'][0]
                        if i >= index and author != last_author:
                            print('i: {}, a: {}'.format(i, author))
                            image = instance['image'].to(gpu)
                            label = instance['label'].to(gpu)
                            a_batch_size = instance['a_batch_size']
                            style = model.extract_style(
                                image, label, a_batch_size)[::a_batch_size]
                            stylesL.append(style)
                            last_author = author
                            index += random.randint(20, 50)
                            print('next index: {}'.format(index))
                        if len(stylesL) >= num_styles:
                            break
                else:  #VAE
                    stylesL = [
                        torch.FloatTensor(1, model.style_dim).normal_()
                        for i in range(num_styles)
                    ]
                images = []
                styles = []
                #step=0.05
                for i in range(num_styles - 1):
                    b_im, b_sty = interpolate(model, stylesL[i].to(gpu),
                                              stylesL[i + 1].to(gpu), text,
                                              char_to_idx, gpu, step)
                    images += b_im
                    styles += b_sty
                b_im, b_sty = interpolate(model, stylesL[-1].to(gpu),
                                          stylesL[0].to(gpu), text,
                                          char_to_idx, gpu, step)
                images += b_im
                styles += b_sty
                for b in range(images[0].size(0)):
                    for i in range(len(images)):
                        genStep = ((1 - images[i][b].permute(1, 2, 0)) *
                                   127.5).cpu().numpy().astype(np.uint8)
                        if step == 0.2 and i % 5 == 0:
                            genStep[0, :] = 0
                            genStep[-1, :] = 0
                            genStep[:, 0] = 0
                            genStep[:, -1] = 0
                        path = os.path.join(saveDir,
                                            'gen{}_{}.png'.format(b, i))
                        #print('wrote: {}'.format(path))
                        cv2.imwrite(path, genStep)
                    torch.save(styles,
                               os.path.join(saveDir, 'styles{}.pth'.format(b)))

            elif action[
                    0] == 'R':  #Just random (interpolated) styles, with option for random text
                assert styles is not None, \
                    'perhaps you forgot to set "-s path/to/styles.pkl"?'
                num_inst = int(input('num to gen? '))
                text = input(
                    'text? (enter "RANDOM" or a file path (.txt) for sampled text): '
                )
                if len(text) == 0:
                    text = 'The quick brown fox jumps over the lazy dog.'
                    textList = None
                elif text == 'RANDOM':
                    text = None
                    textData = TextData(batch_size=num_inst, max_len=55)
                    textList = textData.getInstance()['gt']
                elif text.endswith('.txt'):
                    textData = TextData(batch_size=num_inst,
                                        max_len=55,
                                        textfile=text)
                    textList = textData.getInstance()['gt']
                    text = None
                else:
                    textList = None

                #sample the styles
                stylesL = []
                textL = []
                text_falseL = []
                for i in range(num_inst):
                    if not model.vae:
                        authorA = random.choice(authors)
                        instance = random.randint(0, len(styles[authorA]) - 1)
                        style1 = styles[authorA][instance]
                        authorB = random.choice(authors)
                        instance = random.randint(0, len(styles[authorB]) - 1)
                        style2 = styles[authorB][instance]

                        #inter = random.random()
                        inter = 2 * random.random() - 0.5
                        if charSpec:
                            style = (style1[0] * inter + style2[0] *
                                     (1 - inter), style1[1] * inter +
                                     style2[1] * (1 - inter),
                                     style1[2] * inter + style2[2] *
                                     (1 - inter))
                        else:
                            style = style1 * inter + style2 * (1 - inter)

                        stylesL.append(style)
                    else:  #VAE
                        stylesL = [
                            torch.FloatTensor(1, model.style_dim).normal_()
                            for i in range(num_inst)
                        ]
                ensure_dir(saveDir)
                for i, style in enumerate(stylesL):
                    if charSpec:
                        if gpu is not None:
                            style = (torch.from_numpy(style[0])[None,
                                                                ...].to(gpu),
                                     torch.from_numpy(style[1][None,
                                                               ...]).to(gpu),
                                     torch.from_numpy(style[2][None,
                                                               ...]).to(gpu))
                        else:
                            style = (torch.from_numpy(style[0])[None, ...],
                                     torch.from_numpy(style[1])[None, ...],
                                     torch.from_numpy(style[2])[None, ...])
                    else:
                        if gpu is not None:
                            style = torch.from_numpy(style).to(gpu)
                        else:
                            style = torch.from_numpy(style)
                        style = style[None, ...]

                    if textList is not None:
                        text = textList[i]
                    im = generate(model, style, text, char_to_idx, gpu)
                    im = ((1 - im[0].permute(1, 2, 0)) *
                          127.5).cpu().numpy().astype(np.uint8)
                    image_name = 'sample_{}.png'.format(i)
                    path = os.path.join(saveDir, image_name)
                    cv2.imwrite(path, im)

            elif action[0] == 'm':  #style vector math, this is broken
                assert styles is not None, \
                    'perhaps you forgot to set "-s path/to/styles.pkl"?'

                text = input('text? ')
                if len(text) == 0:
                    text = 'The quick brown fox jumps over the lazy dog.'
                print('elements of expression: author_id, +, -, [author_id,author_id,...]')
                expression = input('expression? ')
                idx = 0
                #style=torch.FloatTensor(1,model.style_dim).zero_()
                m = re.search(r'^(\d+|\+|-|\[[^-\+]+\])', expression[idx:])
                segment = m[0]
                idx += len(segment)
                if segment[0] == '[':
                    nums = [int(s) for s in segment[1:-1].split(',')]
                    style = styles[nums[0]]
                    for num in nums[1:]:
                        style += styles[num]
                    style /= len(nums)
                else:
                    #if normal:
                    style = styles[segment][0]
                    #else:
                    #    style=styles[int(segment)]
                while idx < len(expression):
                    m = re.search(r'^(\d+|\+|-|\[[^-\+]+\])', expression[idx:])
                    operation = m[0]
                    idx += len(operation)

                    m = re.search(r'^(\d+|\+|-|\[[^-\+]+\])', expression[idx:])
                    segment = m[0]
                    idx += len(segment)
                    if segment[0] == '[':
                        nums = [int(s) for s in segment[1:-1].split(',')]
                        subStyle = styles[nums[0]]
                        for num in nums[1:]:
                            subStyle += styles[num]
                        subStyle /= len(nums)
                    else:
                        subStyle = styles[int(segment)]
                    if operation == '+':
                        style += subStyle
                    elif operation == '-':
                        style -= subStyle

                #if normal:
                if type(style) is list:
                    if gpu is not None:
                        style = (torch.from_numpy(style[0])[None, ...].to(gpu),
                                 torch.from_numpy(style[1][None, ...]).to(gpu),
                                 torch.from_numpy(style[2][None, ...]).to(gpu))
                    else:
                        style = (torch.from_numpy(style[0])[None, ...],
                                 torch.from_numpy(style[1])[None, ...],
                                 torch.from_numpy(style[2])[None, ...])
                else:
                    if gpu is not None:
                        style = torch.from_numpy(style).to(gpu)
                    else:
                        style = torch.from_numpy(style)
                    style = style[None, ...]
                #else:
                #    style=style.to(gpu)

                im = generate(model, style, text, char_to_idx, gpu)
                im = ((1 - im[0].permute(1, 2, 0)) *
                      127.5).cpu().numpy().astype(np.uint8)
                path = os.path.join(saveDir, 'result.png')
                cv2.imwrite(path, im)

            elif action == 'A':  #average an authors style vectors together
                author = input("author? ")
                text = input("text? ")
                if len(text) == 0:
                    text = 'The quick brown fox jumps over the lazy dog.'
                max_hits = input("max instances? ")
                if len(max_hits) > 0:
                    max_hits = int(max_hits)
                else:
                    max_hits = 5
                styles = []
                for i, instance1 in enumerate(data_loader):
                    if instance1['author'][0] == author:
                        print('{} found on instance {}'.format(author, i))
                        label1 = instance1['label'].to(gpu)
                        image1 = instance1['image'].to(gpu)
                        a_batch_size = instance1['a_batch_size']
                        styles.append(
                            model.extract_style(image1, label1,
                                                a_batch_size)[::a_batch_size])
                        max_hits -= 1
                        if max_hits <= 0:
                            break
                styles = torch.cat(styles, dim=0)
                style = styles.mean(dim=0)[None, ...]
                im = generate(model, style, text, char_to_idx, gpu)
                im = ((1 - im[0].permute(1, 2, 0)) *
                      127.5).cpu().numpy().astype(np.uint8)
                path = os.path.join(saveDir, 'gen_{}.png'.format(author))
                cv2.imwrite(path, im)

            elif action[0] == 't':  #generate random samples for MTurk test.
                start_index = 0
                assert styles is not None, 'use -a style_loc'
                if arguments is None:
                    num_inst = input('number of instances? ')
                else:
                    num_inst = arguments['num_inst']
                    if 'start_index' in arguments:  #start the image indexing later, so poorly generated images can be added back into the main set
                        start_index = int(arguments['start_index'])
                num_inst = int(num_inst)

                if arguments is None:
                    interpolateS = input(
                        'interpolate? [Y]/N: '
                    )  #whether to interpolate styles, or take them directly from images
                elif 'interpolate' in arguments:
                    interpolateS = arguments['interpolate']
                else:
                    interpolateS = 'Y'
                interpolateS = interpolateS != 'N' and interpolateS != 'n'

                false_full = True

                stylesL = []
                textL = []
                text_falseL = []
                #first build a list of styles
                for i in range(num_inst):
                    if not model.vae:
                        authorA = random.choice(authors)
                        instance = random.randint(0, len(styles[authorA]) - 1)
                        style1 = styles[authorA][instance]
                        if interpolateS:
                            authorB = random.choice(authors)
                            instance = random.randint(0,
                                                      len(styles[authorB]) - 1)
                            style2 = styles[authorB][instance]

                            inter = random.random()
                            if charSpec:
                                style = (style1[0] * inter + style2[0] *
                                         (1 - inter), style1[1] * inter +
                                         style2[1] * (1 - inter),
                                         style1[2] * inter + style2[2] *
                                         (1 - inter))
                            else:
                                style = style1 * inter + style2 * (1 - inter)
                        else:
                            style = style1

                        stylesL.append(style)
                    else:  #VAE
                        stylesL = [
                            torch.FloatTensor(1, model.style_dim).normal_()
                            for i in range(num_inst)
                        ]
                images = []
                ensure_dir(saveDir)
                to_write = []
                with open(os.path.join(saveDir, 'text.csv'), 'w') as text_out:
                    #text.csv is the data for MTurk

                    #save the real images, from test set
                    for i in range(num_inst):
                        index = random.randint(0, len(test_data_loader) - 1)
                        instance = test_data_loader.dataset[index]
                        text = instance['gt'][0]
                        textL.append(text)
                        while (True):
                            indexF = random.randint(0,
                                                    len(test_data_loader) - 1)
                            if indexF != index:
                                break
                        instanceF = test_data_loader.dataset[indexF]
                        textF = instanceF['gt'][0]
                        textF = permuteF(re.sub(r'[^\w\s]', '', text))  #overrides the sampled false text with a permutation of the real text
                        im = ((1 - instance['image'][0].permute(1, 2, 0)) *
                              127.5).cpu().numpy().astype(np.uint8)
                        image_name = 'sample_{}.png'.format(i + start_index)
                        path = os.path.join(saveDir, image_name)
                        cv2.imwrite(path, im)
                        url = 'http://students.cs.byu.edu/~brianld/images/{}'.format(
                            image_name)

                        to_write.append([
                            url,
                            re.sub(r'[^\w\s]', '', text), textF, image_name,
                            'real'
                        ])

                    random.shuffle(textL)
                    #save the fake generated images
                    for i, (style, text) in enumerate(zip(stylesL, textL)):
                        if charSpec:
                            if gpu is not None:
                                style = (torch.from_numpy(
                                    style[0])[None, ...].to(gpu),
                                         torch.from_numpy(
                                             style[1][None, ...]).to(gpu),
                                         torch.from_numpy(
                                             style[2][None, ...]).to(gpu))
                            else:
                                style = (torch.from_numpy(style[0])[None, ...],
                                         torch.from_numpy(style[1])[None, ...],
                                         torch.from_numpy(style[2])[None, ...])
                        else:
                            if gpu is not None:
                                style = torch.from_numpy(style).to(gpu)
                            else:
                                style = torch.from_numpy(style)
                            style = style[None, ...]

                        im = generate(model, style, text, char_to_idx, gpu)
                        im = ((1 - im[0].permute(1, 2, 0)) *
                              127.5).cpu().numpy().astype(np.uint8)
                        image_name = 'sample_{}.png'.format(i + num_inst +
                                                            start_index)
                        path = os.path.join(saveDir, image_name)
                        cv2.imwrite(path, im)
                        url = 'http://students.cs.byu.edu/~brianld/images/{}'.format(
                            image_name)
                        text = re.sub(r'[^\w\s]', '', text)
                        textF = permuteF(text)
                        to_write.append(
                            [url, text, textF, image_name, 'generated'])

                    random.shuffle(to_write)
                    csvwriter = csv.writer(text_out,
                                           delimiter=',',
                                           quotechar='"',
                                           quoting=csv.QUOTE_MINIMAL)
                    csvwriter.writerow([
                        'image_url', 'real_text', 'false_text', 'image_name',
                        'type'
                    ])
                    for l in to_write:
                        csvwriter.writerow(l)

            elif action[
                    0] == 'f':  #interpolate from one given image's style to another's
                if arguments is None:
                    path1 = input("image path 1? ")
                    if len(path1) == 0:
                        path1 = 'real1.png'
                    path2 = input("image path 2? ")
                    if len(path2) == 0:
                        path2 = 'real2.png'
                    text_gen = input("text to generate? ")
                else:
                    path1 = arguments['path1']
                    path2 = arguments['path2']
                    text_gen = arguments[
                        'text_gen'] if 'text_gen' in arguments else arguments[
                            'text']
                img_height = 64

                image1 = cv2.imread(path1, 0)
                if image1.shape[0] != img_height:
                    percent = float(img_height) / image1.shape[0]
                    image1 = cv2.resize(image1, (0, 0),
                                        fx=percent,
                                        fy=percent,
                                        interpolation=cv2.INTER_CUBIC)
                image1 = image1[..., None]
                image1 = image1.astype(np.float32)
                image1 = 1.0 - image1 / 128.0
                image1 = image1.transpose([2, 0, 1])
                image1 = torch.from_numpy(image1)
                if gpu is not None:
                    image1 = image1.to(gpu)

                image2 = cv2.imread(path2, 0)
                if image2.shape[0] != img_height:
                    percent = float(img_height) / image2.shape[0]
                    image2 = cv2.resize(image2, (0, 0),
                                        fx=percent,
                                        fy=percent,
                                        interpolation=cv2.INTER_CUBIC)
                image2 = image2[..., None]
                image2 = image2.astype(np.float32)
                image2 = 1.0 - image2 / 128.0
                image2 = image2.transpose([2, 0, 1])
                image2 = torch.from_numpy(image2)
                if gpu is not None:
                    image2 = image2.to(gpu)

                min_width = min(image1.size(2), image2.size(2))
                style = model.extract_style(
                    torch.stack(
                        (image1[:, :, :min_width], image2[:, :, :min_width]),
                        dim=0), None, 1)
                if type(style) is tuple:
                    style1 = (style[0][0:1], style[1][0:1], style[2][0:1])
                    style2 = (style[0][1:2], style[1][1:2], style[2][1:2])
                else:
                    style1 = style[0:1]
                    style2 = style[1:2]

                images, stylesInter = interpolate(model, style1, style2,
                                                  text_gen, char_to_idx, gpu)

                for b in range(images[0].size(0)):
                    for i in range(len(images)):
                        genStep = ((1 - images[i][b].permute(1, 2, 0)) *
                                   127.5).cpu().numpy().astype(np.uint8)
                        path = os.path.join(saveDir,
                                            'gen{}_{}.png'.format(b, i))
                        #print('wrote: {}'.format(path))
                        cv2.imwrite(path, genStep)

            elif action[
                    0] == 'u':  #UMAP images, with an image for every style; this was to replicate a figure in the GANWriting paper, but we didn't end up using it. May not work
                per_author = 3
                text = 'deep'
                with open(os.path.join(saveDir, 'ordered.txt'), 'w') as f:
                    f.write('{}\n'.format(per_author))
                    for author in authors:
                        for i in range(per_author):
                            style = styles[author][i]
                            if charSpec:
                                if gpu is not None:
                                    style = (torch.from_numpy(
                                        style[0])[None, ...].to(gpu),
                                             torch.from_numpy(
                                                 style[1][None, ...]).to(gpu),
                                             torch.from_numpy(
                                                 style[2][None, ...]).to(gpu))
                                else:
                                    style = (torch.from_numpy(style[0])[None,
                                                                        ...],
                                             torch.from_numpy(style[1])[None,
                                                                        ...],
                                             torch.from_numpy(style[2])[None,
                                                                        ...])
                            else:
                                if gpu is not None:
                                    style = torch.from_numpy(style).to(gpu)
                                else:
                                    style = torch.from_numpy(style)
                                style = style[None, ...]
                            im = generate(model, style, text, char_to_idx, gpu)
                            im = ((1 - im[0].permute(1, 2, 0)) *
                                  127.5).cpu().numpy().astype(np.uint8)
                            image_name = '{}_{}.png'.format(author, i)
                            path = os.path.join(saveDir, image_name)
                            cv2.imwrite(path, im)
                            f.write('{}\n'.format(path))

            else:  #if action=='i' or action=='interpolate':
                if fromDataset and styles is None:
                    index1 = input("batch? ")
                    if len(index1) > 0:
                        index1 = int(index1)
                    else:
                        index1 = 0
                    if index1 >= 0:
                        data = valid_data_loader
                    else:
                        index1 *= -1
                        data = data_loader
                    for i, instance1 in enumerate(data):
                        if i == index1:
                            break
                    author1 = instance1['author'][0]
                    print('author: {}'.format(author1))
                else:
                    author1 = input("author? ")
                    if len(author1) == 0:
                        author1 = authors[0]
                if True:  #new way
                    mask = None
                    index = input("batch? ")
                    text = input("text? ")
                    if len(text) == 0:
                        text = 'The quick brown fox jumps over the lazy dog.'
                    if len(index) > 0:
                        index = int(index)
                    else:
                        index = 0
                    if index >= 0:
                        data = valid_data_loader
                    else:
                        index *= -1
                        data = data_loader
                    for i, instance2 in enumerate(data):
                        if i == index:
                            break
                    author2 = instance2['author'][0]
                    print('author: {}'.format(author2))
                    image1 = instance1['image'].to(gpu)
                    label1 = instance1['label'].to(gpu)
                    image2 = instance2['image'].to(gpu)
                    label2 = instance2['label'].to(gpu)
                    a_batch_size = instance1['a_batch_size']
                    #spaced_label = correct_pred(pred,label)
                    #spaced_label = onehot(spaced_label,num_char)
                    if styles is not None:
                        style1 = styles[author1][0]
                        style2 = styles[author2][0]
                        style1 = torch.from_numpy(style1)
                        style2 = torch.from_numpy(style2)
                    else:
                        style1 = model.extract_style(
                            image1, label1, a_batch_size)[::a_batch_size]
                        style2 = model.extract_style(
                            image2, label2, a_batch_size)[::a_batch_size]
                    images, stylesInter = interpolate(model, style1, style2,
                                                      text, char_to_idx, gpu)

                if mask is not None:
                    mask = ((mask.cpu().permute(0, 2, 3, 1) + 1) / 2.0).numpy()
                for b in range(images[0].size(0)):
                    for i in range(len(images)):
                        genStep = ((1 - images[i][b].permute(1, 2, 0)) *
                                   127.5).cpu().numpy().astype(np.uint8)
                        path = os.path.join(saveDir,
                                            'gen{}_{}.png'.format(b, i))
                        cv2.imwrite(path, genStep)

                    if mask is not None:
                        path_mask = os.path.join(saveDir,
                                                 'mask{}.png'.format(b))
                        cv2.imwrite(path_mask, mask[b])
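
The addToConfig handling near the top of this example walks nested keys and coerces the final value; extracted below as a standalone helper, a sketch assuming each override arrives as a flat list such as ['trainer', 'swa', '0'] (how that list is produced, e.g. by argparse, is outside the snippet).

def add_to_config(config, add):
    """Apply one override: all but the last two entries are nested keys,
    the second-to-last is the key to set, and the last is the value,
    coerced to int or float when possible ('' becomes None)."""
    target = config
    for key in add[:-2]:
        target = target[key]
    value = add[-1]
    if value == "":
        value = None
    else:
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass  # leave as string
    target[add[-2]] = value

cfg = {"trainer": {"swa": True}}
add_to_config(cfg, ["trainer", "swa", "0"])
assert cfg["trainer"]["swa"] == 0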
Example #7
def main(resume,
         saveDir,
         numberOfImages,
         index,
         gpu=None,
         shuffle=False,
         setBatch=None,
         config=None,
         thresh=None,
         addToConfig=None,
         test=False,
         verbose=2):
    np.random.seed(1234)
    torch.manual_seed(1234)
    if resume is not None:
        checkpoint = torch.load(resume,
                                map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
    else:
        checkpoint = None
        config = json.load(open(config))

    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu
    if thresh is not None:
        config['THRESH'] = thresh
        if verbose:
            print('Threshold at {}'.format(thresh))
    addDATASET = False
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            if verbose:
                printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                if verbose:
                    printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            if verbose:
                printM += add[-2] + ']={}'.format(value)
                print(printM)
            if (add[-2] == 'useDetections'
                    or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    #config['data_loader']['batch_size']=math.ceil(config['data_loader']['batch_size']/2)

    config['data_loader']['shuffle'] = shuffle
    #config['data_loader']['rot']=False
    config['validation']['shuffle'] = shuffle
    config['data_loader']['eval'] = True
    config['validation']['eval'] = True
    #config['validation']

    if config['data_loader']['data_set_name'] == 'FormsDetect':
        config['data_loader']['batch_size'] = 1
        del config['data_loader']["crop_params"]
        config['data_loader']["rescale_range"] = config['validation'][
            "rescale_range"]

    #print(config['data_loader'])
    if setBatch is not None:
        config['data_loader']['batch_size'] = setBatch
        config['validation']['batch_size'] = setBatch
    batchSize = config['data_loader']['batch_size']
    if 'batch_size' in config['validation']:
        vBatchSize = config['validation']['batch_size']
    else:
        vBatchSize = batchSize
    if not test:
        data_loader, valid_data_loader = getDataLoader(config, 'train')
    else:
        valid_data_loader, data_loader = getDataLoader(config, 'test')

    if addDATASET:
        config['DATASET'] = valid_data_loader.dataset

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['arch'])(config['model'])
            ##DEBUG
            if 'edgeFeaturizerConv.0.0.weight' in checkpoint['state_dict']:
                keys = list(checkpoint['state_dict'].keys())
                for key in keys:
                    if 'edge' in key:
                        newKey = key.replace('edge', 'rel')
                        checkpoint['state_dict'][newKey] = checkpoint[
                            'state_dict'][key]
                        del checkpoint['state_dict'][key]
            ##DEBUG
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    if verbose:
        model.summary()

    if gpu is not None:
        model = model.to(gpu)
    else:
        model = model.cpu()

    metrics = [eval(metric) for metric in config['metrics']]

    #if "class" in config["trainer"]:
    #    trainer_class = config["trainer"]["class"]
    #else:
    #    trainer_class = "Trainer"

    #saveFunc = eval(trainer_class+'_printer')
    saveFunc = eval(config['data_loader']['data_set_name'] + '_printer')

    step = 5

    #numberOfImages = numberOfImages//config['data_loader']['batch_size']
    #print(len(data_loader))
    if data_loader is not None:
        train_iter = iter(data_loader)
    valid_iter = iter(valid_data_loader)

    with torch.no_grad():

        if index is None:

            if saveDir is not None:
                trainDir = os.path.join(saveDir, 'train_' + config['name'])
                validDir = os.path.join(saveDir, 'valid_' + config['name'])
                if not os.path.isdir(trainDir):
                    os.mkdir(trainDir)
                if not os.path.isdir(validDir):
                    os.mkdir(validDir)
            else:
                trainDir = None
                validDir = None

            val_metrics_sum = np.zeros(len(metrics))
            val_metrics_list = defaultdict(lambda: defaultdict(list))
            val_comb_metrics = defaultdict(list)

            #if numberOfImages==0:
            #    for i in range(len(valid_data_loader)):
            #        print('valid batch index: {}\{} (not save)'.format(i,len(valid_data_loader)),end='\r')
            #        instance=valid_iter.next()
            #        metricsO,_ = saveFunc(config,instance,model,gpu,metrics)

            #        if type(metricsO) == dict:
            #            for typ,typeLists in metricsO.items():
            #                if type(typeLists) == dict:
            #                    for name,lst in typeLists.items():
            #                        val_metrics_list[typ][name]+=lst
            #                        val_comb_metrics[typ]+=lst
            #                else:
            #                    if type(typeLists) is float or type(typeLists) is int:
            #                        typeLists = [typeLists]
            #                    val_comb_metrics[typ]+=typeLists
            #        else:
            #            val_metrics_sum += metricsO.sum(axis=0)/metricsO.shape[0]
            #else:

            ####
            if 'save_nns' in config:
                nns = []
            curVI = 0

            validName = 'valid' if not test else 'test'

            for index in range(0, numberOfImages, step * vBatchSize):

                for validIndex in range(index, index + step * vBatchSize,
                                        vBatchSize):
                    if validIndex // vBatchSize < len(valid_data_loader):
                        print('{} batch index: {}/{}'.format(
                            validName, validIndex // vBatchSize,
                            len(valid_data_loader)),
                              end='\r')
                        #data, target = valid_iter.next() #valid_data_loader[validIndex]
                        curVI += 1
                        #dataT  = _to_tensor(gpu,data)
                        #output = model(dataT)
                        #data = data.cpu().data.numpy()
                        #output = output.cpu().data.numpy()
                        #target = target.data.numpy()
                        #metricsO = _eval_metrics_ind(metrics,output, target)
                        metricsO, aux = saveFunc(config, next(valid_iter),
                                                 model, gpu, metrics, validDir,
                                                 validIndex)
                        if type(metricsO) == dict:
                            for typ, typeLists in metricsO.items():
                                if type(typeLists) == dict:
                                    for name, lst in typeLists.items():
                                        val_metrics_list[typ][name] += lst
                                        val_comb_metrics[typ] += lst
                                else:
                                    if type(typeLists) is float or type(
                                            typeLists) is int:
                                        typeLists = [typeLists]
                                    val_comb_metrics[typ] += typeLists
                        else:
                            val_metrics_sum += metricsO.sum(
                                axis=0) / metricsO.shape[0]

                #if not test:
                #    for trainIndex in range(index,index+step*batchSize, batchSize):
                #        if trainIndex/batchSize < len(data_loader):
                #            print('train batch index: {}/{}'.format(trainIndex/batchSize,len(data_loader)),end='\r')
                #            #data, target = train_iter.next() #data_loader[trainIndex]
                #            #dataT = _to_tensor(gpu,data)
                #            #output = model(dataT)
                #            #data = data.cpu().data.numpy()
                #            #output = output.cpu().data.numpy()
                #            #target = target.data.numpy()
                #            #metricsO = _eval_metrics_ind(metrics,output, target)
                #            _,aux=saveFunc(config,train_iter.next(),model,gpu,metrics,trainDir,trainIndex)
                #            if 'save_nns' in config:
                #                nns+=aux[-1]

            #if gpu is not None or numberOfImages==0:
            try:
                for vi in range(curVI, len(valid_data_loader)):
                    if verbose > 1:
                        print('{} batch index: {}/{} (not saved)'.format(
                            validName, vi, len(valid_data_loader)),
                              end='\r')
                    instance = next(valid_iter)
                    metricsO, _ = saveFunc(config, instance, model, gpu,
                                           metrics)
                    if type(metricsO) == dict:
                        for typ, typeLists in metricsO.items():
                            if type(typeLists) == dict:
                                for name, lst in typeLists.items():
                                    val_metrics_list[typ][name] += lst
                                    val_comb_metrics[typ] += lst
                            else:
                                if type(typeLists) is float or type(
                                        typeLists) is int:
                                    typeLists = [typeLists]
                                if val_comb_metrics is not None and typeLists is not None:
                                    val_comb_metrics[typ] += typeLists
                    else:
                        val_metrics_sum += metricsO.sum(
                            axis=0) / metricsO.shape[0]
            except StopIteration:
                print(
                    'ERROR: ran out of valid batches early. Expected {} more'.
                    format(len(valid_data_loader) - vi))
            ####

            val_metrics_sum /= len(valid_data_loader)
            print('{} metrics'.format(validName))
            for i in range(len(metrics)):
                print(metrics[i].__name__ + ': ' + str(val_metrics_sum[i]))
            for typ in val_comb_metrics:
                print('{} overall mean: {}, std {}'.format(
                    typ, np.mean(val_comb_metrics[typ], axis=0),
                    np.std(val_comb_metrics[typ], axis=0)))
                for name, typeLists in val_metrics_list[typ].items():
                    print('{} {} mean: {}, std {}'.format(
                        typ, name, np.mean(typeLists, axis=0),
                        np.std(typeLists, axis=0)))

            if 'save_nns' in config:
                import pickle
                pickle.dump(nns, open(config['save_nns'], 'wb'))

        elif type(index) == int:
            if index > 0:
                instances = train_iter
            else:
                index *= -1
                instances = valid_iter
            batchIndex = index // batchSize
            inBatchIndex = index % batchSize
            for i in range(batchIndex + 1):
                instance = next(instances)
            #data, target = data[inBatchIndex:inBatchIndex+1], target[inBatchIndex:inBatchIndex+1]
            #dataT = _to_tensor(gpu,data)
            #output = model(dataT)
            #data = data.cpu().data.numpy()
            #output = output.cpu().data.numpy()
            #target = target.data.numpy()
            #print (output.shape)
            #print ((output.min(), output.amin()))
            #print (target.shape)
            #print ((target.amin(), target.amin()))
            #metricsO = _eval_metrics_ind(metrics,output, target)
            saveFunc(config, instance, model, gpu, metrics, saveDir,
                     batchIndex * batchSize)
        else:
            for instance in data_loader:
                if index in instance['imgName']:
                    break
            if index not in instance['imgName']:
                for instance in valid_data_loader:
                    if index in instance['imgName']:
                        break
            if index in instance['imgName']:
                saveFunc(config, instance, model, gpu, metrics, saveDir, 0)
            else:
                print('{} not found! (on {})'.format(index,
                                                     instance['imgName']))
Example No. 8
def main(resume,
         saveDir,
         numberOfImages,
         index,
         gpu=None,
         shuffle=False,
         setBatch=None,
         config=None,
         thresh=None,
         addToConfig=None,
         test=False,
         toEval=None,
         verbosity=2):
    np.random.seed(1234)
    torch.manual_seed(1234)
    if resume is not None:
        checkpoint = torch.load(resume,
                                map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        loaded_iteration = checkpoint['iteration']
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if type(config[key]) is dict:
                for key2 in config[key].keys():
                    if key2.startswith('pretrained'):
                        config[key][key2] = None
    else:
        checkpoint = None
        config = json.load(open(config))
        loaded_iteration = None
    config['optimizer_type'] = "none"
    config['trainer']['use_learning_schedule'] = False
    config['trainer']['swa'] = False
    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu
    if thresh is not None:
        config['THRESH'] = thresh
        print('Threshold at {}'.format(thresh))
    addDATASET = False
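    # Each addToConfig entry is a list of keys ending in a value, e.g.
    # ['data_loader','batch_size','8'] sets config['data_loader']['batch_size']=8.
    # Values of the form '[a-b]' become lists; numeric strings are cast to int/float.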
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            elif value[0] == '[' and value[-1] == ']':
                value = value[1:-1].split('-')
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            printM += add[-2] + ']={}'.format(value)
            print(printM)
            if (add[-2] == 'useDetections'
                    or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    #config['data_loader']['batch_size']=math.ceil(config['data_loader']['batch_size']/2)
    if 'save_spaced' in config:
        spaced = {}
        spaced_val = {}
        if toEval is None:
            toEval = ['spaced_label']
        elif 'spaced_label' not in toEval:
            toEval.append('spaced_label')
        config['data_loader']['batch_size'] = 1
        config['validation']['batch_size'] = 1
        if 'a_batch_size' in config['data_loader']:
            config['data_loader']['a_batch_size'] = 1
        if 'a_batch_size' in config['validation']:
            config['validation']['a_batch_size'] = 1

    config['data_loader']['shuffle'] = shuffle
    #config['data_loader']['rot']=False
    config['validation']['shuffle'] = shuffle
    config['data_loader']['eval'] = True
    config['validation']['eval'] = True
    #config['validation']

    if config['data_loader']['data_set_name'] == 'FormsDetect':
        config['data_loader']['batch_size'] = 1
        del config['data_loader']["crop_params"]
        config['data_loader']["rescale_range"] = config['validation'][
            "rescale_range"]

    #print(config['data_loader'])
    if setBatch is not None:
        config['data_loader']['batch_size'] = setBatch
        config['validation']['batch_size'] = setBatch
    batchSize = config['data_loader']['batch_size']
    if 'batch_size' in config['validation']:
        vBatchSize = config['validation']['batch_size']
    else:
        vBatchSize = batchSize
    if not test:
        data_loader, valid_data_loader = getDataLoader(config, 'train')
    else:
        valid_data_loader, data_loader = getDataLoader(config, 'test')

    if addDATASET:
        config['DATASET'] = valid_data_loader.dataset
    #ttt=FormsDetect(dirPath='/home/ubuntu/brian/data/forms',split='train',config={'crop_to_page':False,'rescale_range':[450,800],'crop_params':{"crop_size":512},'no_blanks':True, "only_types": ["text_start_gt"], 'cache_resized_images': True})
    #data_loader = torch.utils.data.DataLoader(ttt, batch_size=16, shuffle=False, num_workers=5, collate_fn=forms_detect.collate)
    #valid_data_loader = data_loader.split_validation()

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['arch'])(config['model'])
            if config['trainer']['class'] == 'HWRWithSynthTrainer':
                model = model.hwr
            if 'style' in config['model'] and 'lookup' in config['model'][
                    'style']:
                model.style_extractor.add_authors(
                    data_loader.dataset.authors)  ##HERE
            ##HACK fix
            keys = list(checkpoint['state_dict'].keys())
            for key in keys:
                if 'style_from_normal' in key:  #HACK
                    del checkpoint['state_dict'][key]
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    if verbosity > 1:
        model.summary()

    if type(config['loss']) == dict:
        loss = {}  #[eval(l) for l in config['loss']]
        for name, l in config['loss'].items():
            loss[name] = eval(l)
    else:
        loss = eval(config['loss'])
    metrics = [eval(metric) for metric in config['metrics']]

    train_logger = Logger()
    trainerClass = eval(config['trainer']['class'])
    trainer = trainerClass(
        model,
        loss,
        metrics,
        resume=False,  #path
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        train_logger=train_logger)
    #saveFunc = eval(trainer_class+'_printer')
    saveFunc = eval(config['data_loader']['data_set_name'] + '_eval')
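    # saveFunc resolves a dataset-specific eval function by name,
    # e.g. FormsDetect_eval for data_set_name 'FormsDetect'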

    step = 5

    #numberOfImages = numberOfImages//config['data_loader']['batch_size']
    #print(len(data_loader))
    if data_loader is not None:
        train_iter = iter(data_loader)
    if valid_data_loader is not None:
        valid_iter = iter(valid_data_loader)

    with torch.no_grad():

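        # Three modes: index=None evaluates the loaders in full; an int selects
        # a single instance by image index (negative values index the validation
        # set); any other value is treated as an image name to search for.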
        if index is None:

            if saveDir is not None:
                trainDir = os.path.join(saveDir, 'train_' + config['name'])
                validDir = os.path.join(saveDir, 'valid_' + config['name'])
                if not os.path.isdir(trainDir):
                    os.mkdir(trainDir)
                if not os.path.isdir(validDir):
                    os.mkdir(validDir)

                if loaded_iteration is not None:
                    with open(
                            os.path.join(
                                validDir,
                                'z_iter_{}.txt'.format(loaded_iteration)),
                            'w') as f:
                        f.write('{}'.format(loaded_iteration))
            else:
                trainDir = None
                validDir = None

            val_metrics_sum = np.zeros(len(metrics))
            val_metrics_list = defaultdict(lambda: defaultdict(list))
            val_comb_metrics = defaultdict(list)

            #if numberOfImages==0:
            #    for i in range(len(valid_data_loader)):
            #        print('valid batch index: {}\{} (not save)'.format(i,len(valid_data_loader)),end='\r')
            #        instance=valid_iter.next()
            #        metricsO,_ = saveFunc(config,instance,model,gpu,metrics)

            #        if type(metricsO) == dict:
            #            for typ,typeLists in metricsO.items():
            #                if type(typeLists) == dict:
            #                    for name,lst in typeLists.items():
            #                        val_metrics_list[typ][name]+=lst
            #                        val_comb_metrics[typ]+=lst
            #                else:
            #                    if type(typeLists) is float or type(typeLists) is int:
            #                        typeLists = [typeLists]
            #                    val_comb_metrics[typ]+=typeLists
            #        else:
            #            val_metrics_sum += metricsO.sum(axis=0)/metricsO.shape[0]
            #else:

            ####
            if 'save_spaced' in config:
                spaced = {}
                spaced_val = {}
                assert (config['data_loader']['batch_size'] == 1)
                assert (config['validation']['batch_size'] == 1)
                if 'a_batch_size' in config['data_loader']:
                    assert (config['data_loader']['a_batch_size'] == 1)
                if 'a_batch_size' in config['validation']:
                    assert (config['validation']['a_batch_size'] == 1)
            if 'save_nns' in config:
                nns = []
            if 'save_style' in config:
                if toEval is None:
                    toEval = []
                if 'style' not in toEval:
                    toEval.append('style')
                if 'author' not in toEval:
                    toEval.append('author')
                styles = []
                authors = []
                strings = []
                stylesVal = []
                authorsVal = []
                spacedVal = []
                stringsVal = []
                spaced = []

                doIds = config['data_loader'][
                    'data_set_name'] == 'StyleWordDataset'
                #doSpaced = not doIds#?
                doSpaced = 'doSpaced' in config
                if doSpaced:
                    if 'spaced_label' not in toEval:
                        toEval.append('spaced_label')
                    if 'gt' not in toEval:
                        toEval.append('gt')
                ids = []
                idsVal = []
                saveStyleEvery = config[
                    'saveStyleEvery'] if 'saveStyleEvery' in config else 5000
                saveStyleLoc = config['save_style']
                lastSlash = saveStyleLoc.rfind('/')
                if lastSlash >= 0:
                    saveStyleValLoc = saveStyleLoc[:lastSlash +
                                                   1] + 'val_' + saveStyleLoc[
                                                       lastSlash + 1:]
                else:
                    saveStyleValLoc = 'val_' + saveStyleLoc

            if 'save_preds' in config:
                to_save = []

            validName = 'valid' if not test else 'test'

            startBatch = config['startBatch'] if 'startBatch' in config else 0
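            # numberOfImages counts images, not batches; round down to whole
            # batches, keeping at least one when more than one image is requested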
            numberOfBatches = numberOfImages // batchSize
            if numberOfBatches == 0 and numberOfImages > 1:
                numberOfBatches = 1

            #for index in range(startIndex,numberOfImages,step*batchSize):
            batch = startBatch
            numberOfBatches = min(
                numberOfBatches, max(len(valid_data_loader), len(data_loader)))
            for batch in range(startBatch, numberOfBatches):

                #for validIndex in range(index,index+step*vBatchSize, vBatchSize):
                #for validBatch
                #if valyypidIndex/vBatchSize < len(valid_data_loader):
                if valid_data_loader is not None and batch < len(
                        valid_data_loader) and 'skip_valid' not in config:
                    print('{} batch index: {}/{}       '.format(
                        validName, batch, len(valid_data_loader)),
                          end='\r')
                    #data, target = valid_iter.next() #valid_data_loader[validIndex]
                    #dataT  = _to_tensor(gpu,data)
                    #output = model(dataT)
                    #data = data.cpu().data.numpy()
                    #output = output.cpu().data.numpy()
                    #target = target.data.numpy()
                    #metricsO = _eval_metrics_ind(metrics,output, target)
                    metricsO, aux = saveFunc(config,
                                             next(valid_iter),
                                             trainer,
                                             metrics,
                                             validDir,
                                             batch * vBatchSize,
                                             toEval=toEval)
                    if type(metricsO) == dict:
                        for typ, typeLists in metricsO.items():
                            if type(typeLists) == dict:
                                for name, lst in typeLists.items():
                                    val_metrics_list[typ][name] += lst
                                    val_comb_metrics[typ] += lst
                            else:
                                if type(typeLists) is float or type(
                                        typeLists) is int:
                                    typeLists = [typeLists]
                                val_comb_metrics[typ] += typeLists
                    else:
                        val_metrics_sum += metricsO.sum(
                            axis=0) / metricsO.shape[0]
                    if 'save_spaced' in config:
                        spaced_val[aux['name'][0]] = aux['spaced_label'].cpu()
                    if 'save_style' in config:
                        stylesVal.append(aux['style'].cpu())
                        authorsVal += aux['author']
                        if doIds:
                            idsVal += aux['name']
                        elif doSpaced:
                            #spacedVal.append(aux[2])
                            spacedVal += aux['spaced_label']
                            stringsVal += aux['gt']
                        if batch > 0 and batch % saveStyleEvery == 0:
                            save_style(saveStyleValLoc, batch, stylesVal,
                                       authorsVal, idsVal, doIds, spacedVal,
                                       stringsVal, doSpaced)
                            stylesVal = []
                            authorsVal = []
                            idsVal = []
                            spacedVal = []
                            stringsVal = []

                if not test and 'skip_train' not in config:
                    #for trainIndex in range(index,index+step*batchSize, batchSize):
                    #    if trainIndex/batchSize < len(data_loader):
                    if batch < len(data_loader):
                        print('train batch index: {}/{}        '.format(
                            batch, len(data_loader)),
                              end='\r')
                        #data, target = train_iter.next() #data_loader[trainIndex]
                        #dataT = _to_tensor(gpu,data)
                        #output = model(dataT)
                        #data = data.cpu().data.numpy()
                        #output = output.cpu().data.numpy()
                        #target = target.data.numpy()
                        #metricsO = _eval_metrics_ind(metrics,output, target)
                        instance = next(train_iter)
                        _, aux = saveFunc(config,
                                          instance,
                                          trainer,
                                          metrics,
                                          trainDir,
                                          batch * batchSize,
                                          toEval=toEval)
                        if 'save_nns' in config:
                            nns += aux[-1]
                        if 'save_spaced' in config:
                            spaced[aux['name'][0]] = aux['spaced_label'].cpu()
                        if 'save_style' in config:
                            styles.append(aux['style'].cpu())
                            authors += aux['author']
                            if doIds:
                                ids += aux['name']
                            elif doSpaced:
                                #spaced.append(aux[2])
                                spaced += aux['spaced_label']
                                strings += aux['gt']
                            if batch > 0 and batch % saveStyleEvery == 0:
                                save_style(saveStyleLoc, batch, styles,
                                           authors, ids, doIds, spaced,
                                           strings, doSpaced)
                                styles = []
                                authors = []
                                ids = []
                                spaced = []
                                strings = []
                        if 'save_preds' in config:
                            for b in range(batchSize):
                                try:
                                    to_save.append([
                                        instance['name'][b], instance['gt'][b],
                                        aux['pred_str'][b], aux['cer'][b]
                                    ])
                                except IndexError:
                                    pass

            #if gpu is not None or numberOfImages==0:
            if 'save_preds' in config:
                with open(config['save_preds'], 'w') as f:
                    csvwriter = csv.writer(f,
                                           delimiter=',',
                                           quotechar='"',
                                           quoting=csv.QUOTE_MINIMAL)
                    for l in to_save:
                        csvwriter.writerow(l)
                print('wrote results to {}'.format(config['save_preds']))
            if saveDir is None:
                try:
                    if valid_data_loader is not None:
                        for vi in range(batch, len(valid_data_loader)):
                            #print('{} batch index: {}\{} (not save)   '.format(validName,vi,len(valid_data_loader)),end='\r')
                            instance = next(valid_iter)
                            metricsO, aux = saveFunc(config,
                                                     instance,
                                                     trainer,
                                                     metrics,
                                                     toEval=toEval)
                            if type(metricsO) == dict:
                                for typ, typeLists in metricsO.items():
                                    if type(typeLists) == dict:
                                        for name, lst in typeLists.items():
                                            val_metrics_list[typ][name] += lst
                                            val_comb_metrics[typ] += lst
                                    else:
                                        if type(typeLists) is float or type(
                                                typeLists) is int:
                                            typeLists = [typeLists]
                                        val_comb_metrics[typ] += typeLists
                            else:
                                val_metrics_sum += metricsO.sum(
                                    axis=0) / metricsO.shape[0]
                            if 'save_spaced' in config:
                                spaced_val[aux['name'][0]] = aux['spaced_label'].cpu()
                            if 'save_style' in config:
                                stylesVal.append(aux['style'].cpu())
                                authorsVal += aux['author']
                                if doIds:
                                    idsVal += aux['name']
                                elif doSpaced:
                                    #spacedVal.append(aux[2])
                                    spacedVal += aux['spaced_label']
                                    stringsVal += aux['gt']
                                if vi > 0 and vi % saveStyleEvery == 0:
                                    save_style(saveStyleValLoc, vi, stylesVal,
                                               authorsVal, idsVal, doIds,
                                               spacedVal, stringsVal, doSpaced)
                                    stylesVal = []
                                    authorsVal = []
                                    idsVal = []
                                    spacedVal = []
                                    stringsVal = []
                except StopIteration:
                    print(
                        'ERROR: ran out of valid batches early. Expected {} more'
                        .format(len(valid_data_loader) - vi))
            ####
            if valid_data_loader is not None:
                val_metrics_sum /= len(valid_data_loader)
            print('{} metrics'.format(validName))
            for i in range(len(metrics)):
                print(metrics[i].__name__ + ': ' + str(val_metrics_sum[i]))
            for typ in val_comb_metrics:
                print('{} overall mean: {}, std {}'.format(
                    typ, np.mean(val_comb_metrics[typ], axis=0),
                    np.std(val_comb_metrics[typ], axis=0)))
                for name, typeLists in val_metrics_list[typ].items():
                    print('{} {} mean: {}, std {}'.format(
                        typ, name, np.mean(typeLists, axis=0),
                        np.std(typeLists, axis=0)))

            if 'save_nns' in config:
                with open(config['save_nns'], 'wb') as f:
                    pickle.dump(nns, f)
            if 'save_spaced' in config:
                #import pdb;pdb.set_trace()
                #spaced = torch.cat(spaced,dim=1).numpy()
                #spaced_val = torch.cat(spaced_val,dim=1).numpy()
                saveSpacedLoc = config['save_spaced']
                lastSlash = saveSpacedLoc.rfind('/')
                if lastSlash >= 0:
                    saveSpacedValLoc = saveSpacedLoc[:lastSlash +
                                                     1] + 'val_' + saveSpacedLoc[
                                                         lastSlash + 1:]
                else:
                    saveSpacedValLoc = 'val_' + saveSpacedLoc
                with open(saveSpacedLoc, 'wb') as f:
                    pickle.dump(spaced, f)

                with open(saveSpacedValLoc, 'wb') as f:
                    pickle.dump(spaced_val, f)

            if 'save_style' in config:
                if len(styles) > 0:
                    assert (not doSpaced)
                    save_style(saveStyleLoc, len(data_loader), styles, authors,
                               ids, doIds)
                if len(stylesVal) > 0:
                    save_style(saveStyleValLoc, len(valid_data_loader),
                               stylesVal, authorsVal, idsVal, doIds)
        elif type(index) == int:
            if index > 0:
                instances = train_iter
            else:
                index *= -1
                instances = valid_iter
            batchIndex = index // batchSize
            inBatchIndex = index % batchSize
            for i in range(batchIndex + 1):
                instance = next(instances)
            #data, target = data[inBatchIndex:inBatchIndex+1], target[inBatchIndex:inBatchIndex+1]
            #dataT = _to_tensor(gpu,data)
            #output = model(dataT)
            #data = data.cpu().data.numpy()
            #output = output.cpu().data.numpy()
            #target = target.data.numpy()
            #print (output.shape)
            #print ((output.min(), output.amin()))
            #print (target.shape)
            #print ((target.amin(), target.amin()))
            #metricsO = _eval_metrics_ind(metrics,output, target)
            saveFunc(config,
                     instance,
                     model,
                     gpu,
                     metrics,
                     saveDir,
                     batchIndex * batchSize,
                     toEval=toEval)
        else:
            for instance in data_loader:
                if index in instance['imgName']:
                    break
            if index not in instance['imgName']:
                for instance in valid_data_loader:
                    if index in instance['imgName']:
                        break
            if index in instance['imgName']:
                saveFunc(config,
                         instance,
                         model,
                         gpu,
                         metrics,
                         saveDir,
                         0,
                         toEval=toEval)
            else:
                print('{} not found! (on {})'.format(index,
                                                     instance['imgName']))
Example No. 9
def main(resume,
         saveDir,
         index,
         gpu=None,
         shuffle=False,
         setBatch=None,
         config=None,
         addToConfig=None,
         test=False,
         verbosity=2,
         transform_style=False):
    assert (saveDir is not None)
    np.random.seed(1234)
    torch.manual_seed(1234)
    if resume is not None:
        checkpoint = torch.load(resume,
                                map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        loaded_iteration = checkpoint['iteration']
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if type(config[key]) is dict:
                for key2 in config[key].keys():
                    if key2.startswith('pretrained'):
                        config[key][key2] = None
    else:
        checkpoint = None
        config = json.load(open(config))
        loaded_iteration = None

    train_loc = os.path.join(saveDir,
                             'train_styles_{}.pkl'.format(loaded_iteration))
    if not test:
        val_loc = os.path.join(saveDir,
                               'val_styles_{}.pkl'.format(loaded_iteration))
    else:
        val_loc = os.path.join(saveDir,
                               'test_styles_{}.pkl'.format(loaded_iteration))

    config['optimizer_type'] = "none"
    config['trainer']['use_learning_schedule'] = False
    config['trainer']['swa'] = False
    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu
    addDATASET = False
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            elif value[0] == '[' and value[-1] == ']':
                value = value[1:-1].split('-')
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            printM += add[-2] + ']={}'.format(value)
            print(printM)
            if (add[-2] == 'useDetections'
                    or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    config['data_loader']['shuffle'] = shuffle
    #config['data_loader']['rot']=False
    config['validation']['shuffle'] = shuffle
    config['data_loader']['eval'] = True
    config['validation']['eval'] = True
    #config['validation']

    if config['data_loader']['data_set_name'] == 'FormsDetect':
        config['data_loader']['batch_size'] = 1
        del config['data_loader']["crop_params"]
        config['data_loader']["rescale_range"] = config['validation'][
            "rescale_range"]

    #print(config['data_loader'])
    if setBatch is not None:
        config['data_loader']['batch_size'] = setBatch
        config['validation']['batch_size'] = setBatch
    batchSize = config['data_loader']['batch_size']
    if 'batch_size' in config['validation']:
        vBatchSize = config['validation']['batch_size']
    else:
        vBatchSize = batchSize
    if not test:
        data_loader, valid_data_loader = getDataLoader(config, 'train')
    else:
        valid_data_loader, data_loader = getDataLoader(config, 'test')

    if addDATASET:
        config['DATASET'] = valid_data_loader.dataset
    #ttt=FormsDetect(dirPath='/home/ubuntu/brian/data/forms',split='train',config={'crop_to_page':False,'rescale_range':[450,800],'crop_params':{"crop_size":512},'no_blanks':True, "only_types": ["text_start_gt"], 'cache_resized_images': True})
    #data_loader = torch.utils.data.DataLoader(ttt, batch_size=16, shuffle=False, num_workers=5, collate_fn=forms_detect.collate)
    #valid_data_loader = data_loader.split_validation()

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['arch'])(config['model'])
            if config['trainer']['class'] == 'HWRWithSynthTrainer':
                model = model.hwr
            if 'style' in config['model'] and 'lookup' in config['model'][
                    'style']:
                model.style_extractor.add_authors(
                    data_loader.dataset.authors)  ##HERE
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    if verbosity > 1:
        model.summary()

    if type(config['loss']) == dict:
        loss = {}  #[eval(l) for l in config['loss']]
        for name, l in config['loss'].items():
            loss[name] = eval(l)
    else:
        loss = eval(config['loss'])
    metrics = [eval(metric) for metric in config['metrics']]

    train_logger = Logger()
    trainerClass = eval(config['trainer']['class'])
    trainer = trainerClass(
        model,
        loss,
        metrics,
        resume=False,  #path
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        train_logger=train_logger)
    #saveFunc = eval(trainer_class+'_printer')
    saveFunc = eval(config['data_loader']['data_set_name'] + '_eval')

    step = 5

    if data_loader is not None:
        train_iter = iter(data_loader)
    if valid_data_loader is not None:
        valid_iter = iter(valid_data_loader)

    with torch.no_grad():

        if index is None:

            val_metrics_sum = np.zeros(len(metrics))
            val_metrics_list = defaultdict(lambda: defaultdict(list))
            val_comb_metrics = defaultdict(list)

            validName = 'valid' if not test else 'test'
            charSpec = trainer.model.char_style_dim > 0
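            # with per-character style dims the extractor returns a tuple of
            # three tensors per batch instead of a single style tensor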

            train_styles = []
            train_authors = []
            if not test:
                for i, instance in enumerate(data_loader):
                    print('train: {}/{}       '.format(i, len(data_loader)),
                          end='\r')
                    image, label = trainer._to_tensor(instance)
                    batch_size = label.size(1)
                    label_lengths = instance['label_lengths']
                    a_batch_size = trainer.a_batch_size if 'a_batch_size' in instance else None

                    style = trainer.model.extract_style(
                        image, label, a_batch_size)

                    if transform_style:
                        style = trainer.model.generator.style_emb(style)
                    if charSpec:
                        for b in range(batch_size):
                            train_styles.append(
                                (style[0][b].cpu(), style[1][b].cpu(),
                                 style[2][b].cpu()))
                    else:
                        train_styles.append(style.cpu())
                    train_authors += instance['author']

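                    # drop tensors the model cached during extract_style so
                    # they are not kept alive between batches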
                    trainer.model.spaced_label = None
                    trainer.model.mask = None
                    trainer.model.gen_mask = None
                    trainer.model.top_and_bottom = None
                    trainer.model.counts = None
                    trainer.model.pred = None
                    trainer.model.spacing_pred = None
                    trainer.model.mask_pred = None
                    trainer.model.gen_spaced = None
                    trainer.model.spaced_style = None
                    trainer.model.mu = None
                    trainer.model.sigma = None
                if charSpec:
                    train_styles = [(s[0].numpy(), s[1].numpy(), s[2].numpy())
                                    for s in train_styles]
                else:
                    train_styles = torch.cat(train_styles, dim=0).numpy()
                train_authors = np.array(train_authors)
                with open(train_loc, 'wb') as f:
                    pickle.dump({
                        'styles': train_styles,
                        'authors': train_authors
                    }, f)
                print('saved {}'.format(train_loc))

            val_styles = []
            val_authors = []
            for i, instance in enumerate(valid_data_loader):
                print('{}: {}/{}       '.format(validName, i,
                                                len(valid_data_loader)),
                      end='\r')
                image, label = trainer._to_tensor(instance)
                batch_size = label.size(1)
                label_lengths = instance['label_lengths']
                a_batch_size = trainer.a_batch_size if 'a_batch_size' in instance else None

                style = trainer.model.extract_style(image, label, a_batch_size)

                if transform_style:
                    style = trainer.model.generator.style_emb(style)
                if charSpec:
                    for b in range(batch_size):
                        val_styles.append(
                            (style[0][b].cpu(), style[1][b].cpu(),
                             style[2][b].cpu()))
                else:
                    val_styles.append(style.cpu())

                val_authors += instance['author']

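                # same cache clearing as in the train loop above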
                trainer.model.spaced_label = None
                trainer.model.mask = None
                trainer.model.gen_mask = None
                trainer.model.top_and_bottom = None
                trainer.model.counts = None
                trainer.model.pred = None
                trainer.model.spacing_pred = None
                trainer.model.mask_pred = None
                trainer.model.gen_spaced = None
                trainer.model.spaced_style = None
                trainer.model.mu = None
                trainer.model.sigma = None
            if charSpec:
                val_styles = [(s[0].numpy(), s[1].numpy(), s[2].numpy())
                              for s in val_styles]
                assert (len(val_styles) == len(val_authors))
            else:
                val_styles = torch.cat(val_styles, dim=0).numpy()
            val_authors = np.array(val_authors)
            with open(val_loc, 'wb') as f:
                pickle.dump({
                    'styles': val_styles,
                    'authors': val_authors
                }, f)
            print('saved {}'.format(val_loc))
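
A minimal sketch of loading the style pickles written above; the 'styles' and 'authors' keys match the dicts dumped by this script, but the filename is only an assumed example:

import pickle

# Hypothetical consumer of the pickles produced above; train_loc/val_loc
# follow the 'train_styles_<iteration>.pkl' naming used by this script.
with open('train_styles_100000.pkl', 'rb') as f:  # assumed iteration number
    data = pickle.load(f)
styles = data['styles']    # ndarray [N, style_dim], or a list of 3-tuples when charSpec
authors = data['authors']  # author ids aligned with styles
print('loaded {} styles'.format(len(authors)))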
Example No. 10
def main(rank, config, resume, world_size=None):
    if rank is not None:  #multiprocessing
        #print('Process {} can see these GPUs:'.format(rank,os.environ['CUDA_VISIBLE_DEVICES']))
        if 'distributed' in config:
            print('env NCCL_SOCKET_IFNAME: {}'.format(
                os.environ['NCCL_SOCKET_IFNAME']))
            print('{} calling dist.init_process_group()'.format(rank))
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            dist.init_process_group(
                "nccl",
                init_method='file:///fslhome/brianld/job_comm/{}'.format(
                    config['name']),
                rank=rank,
                world_size=world_size)
            print('{} finished dist.init_process_group()'.format(rank))
        else:
            dist.init_process_group("gloo", rank=rank, world_size=world_size)

    #np.random.seed(1234) I don't have a way of restarting the DataLoader at the same place, so this makes it totaly random
    train_logger = Logger()

    split = config['split'] if 'split' in config else 'train'
    data_loader, valid_data_loader = getDataLoader(config, split, rank,
                                                   world_size)
    #valid_data_loader = data_loader.split_validation()

    model = eval(config['arch'])(config['model'])
    model.summary()

    if type(config['loss']) == dict:
        loss = {}  #[eval(l) for l in config['loss']]
        for name, l in config['loss'].items():
            loss[name] = eval(l)
    else:
        loss = eval(config['loss'])
    if type(config['metrics']) == dict:
        metrics = {}
        for name, m in config['metrics'].items():
            metrics[name] = [eval(metric) for metric in m]
    else:
        metrics = [eval(metric) for metric in config['metrics']]

    if 'class' in config['trainer']:
        trainerClass = eval(config['trainer']['class'])
    else:
        trainerClass = Trainer
    trainer = trainerClass(model,
                           loss,
                           metrics,
                           resume=resume,
                           config=config,
                           data_loader=data_loader,
                           valid_data_loader=valid_data_loader,
                           train_logger=train_logger)

    name = config['name']
    supercomputer = config[
        'super_computer'] if 'super_computer' in config else False

    if rank is not None and rank != 0:
        trainer.side_process = rank  #this tells the trainer not to log or validate on this thread
    else:
        trainer.finishSetup()

        def handleSIGINT(sig, frame):
            trainer.save()
            sys.exit(0)

        signal.signal(signal.SIGINT, handleSIGINT)

    print("Begin training")
    #warnings.filterwarnings("error")
    trainer.train()
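
A minimal launcher sketch for the distributed main() above, assuming it is run in the same module and that one process is spawned per GPU; torch.multiprocessing passes the rank as the first argument, matching main(rank, config, resume, world_size). The config path and world_size here are assumptions:

import json
import torch.multiprocessing as mp

if __name__ == '__main__':
    with open('cf_example.json') as f:  # hypothetical config file
        config = json.load(f)
    world_size = 2  # assumed number of processes/GPUs
    mp.spawn(main, args=(config, None, world_size), nprocs=world_size)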
Example No. 11
def main(resume,saveDir,numberOfImages,index,gpu=None, shuffle=False, setBatch=None, config=None, thresh=None, addToConfig=None, test=False, toEval=None,verbosity=2, do_train=False, use_train_model=False):
    np.random.seed(1234)
    torch.manual_seed(1234)
    if resume is not None:
        checkpoint = torch.load(resume, map_location=lambda storage, location: storage)
        print('loaded {} iteration {}'.format(checkpoint['config']['name'],checkpoint['iteration']))
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if 'pretrained' in key:
                config[key]=None
    else:
        checkpoint = None
        config = json.load(open(config))
    config['optimizer_type']="none"
    config['trainer']['use_learning_schedule']=False
    config['trainer']['swa']=False
    if gpu is None:
        config['cuda']=False
    else:
        config['cuda']=True
        config['gpu']=gpu
    if thresh is not None:
        config['THRESH'] = thresh
        print('Threshold at {}'.format(thresh))
    config['model']['max_graph_size']=750
    config['model']['max_graph_cand']=700
    config['data_loader']['pixel_count_thresh']=900000000000
    config['data_loader']['max_dim_thresh']=999999999
    addDATASET=False
    if addToConfig is not None:
        for add in addToConfig:
            addTo=config
            printM='added config['
            for i in range(len(add)-2):
                try:
                    indName = int(add[i])
                except ValueError:
                    indName = add[i]
                addTo = addTo[indName]
                printM+=add[i]+']['
            value = add[-1]
            if value=="":
                value=None
            elif value[0]=='[' and value[-1]==']':
                value = value[1:-1].split('-')
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        if value == 'None':
                            value=None
            addTo[add[-2]] = value
            printM+=add[-2]+']={}'.format(value)
            print(printM)
            #if (add[-2]=='useDetections' or add[-2]=='useDetect') and 'gt' not in value:
            #    addDATASET=True
        
    #config['data_loader']['batch_size']=math.ceil(config['data_loader']['batch_size']/2)
    if 'save_spaced' in config:
        spaced={}
        spaced_val={}
        config['data_loader']['batch_size']=1
        config['validation']['batch_size']=1
        if 'a_batch_size' in config['data_loader']:
            config['data_loader']['a_batch_size']=1
        if 'a_batch_size' in config['validation']:
            config['validation']['a_batch_size']=1
    
    config['data_loader']['shuffle']=shuffle
    #config['data_loader']['rot']=False
    config['validation']['shuffle']=shuffle
    config['data_loader']['eval']=True
    config['validation']['eval']=True
    #config['validation']

    if config['data_loader']['data_set_name']=='FormsDetect':
        config['data_loader']['batch_size']=1
        del config['data_loader']["crop_params"]
        config['data_loader']["rescale_range"]= config['validation']["rescale_range"]

    #print(config['data_loader'])
    if setBatch is not None:
        config['data_loader']['batch_size']=setBatch
        config['validation']['batch_size']=setBatch
    batchSize = config['data_loader']['batch_size']
    if 'batch_size' in config['validation']:
        vBatchSize = config['validation']['batch_size']
    else:
        vBatchSize = batchSize
    if not test:
        data_loader, valid_data_loader = getDataLoader(config,'train')
    else:
        valid_data_loader, data_loader = getDataLoader(config,'test')
        data_loader = valid_data_loader

    if addDATASET:
        config['DATASET']=valid_data_loader.dataset
    #ttt=FormsDetect(dirPath='/home/ubuntu/brian/data/forms',split='train',config={'crop_to_page':False,'rescale_range':[450,800],'crop_params':{"crop_size":512},'no_blanks':True, "only_types": ["text_start_gt"], 'cache_resized_images': True})
    #data_loader = torch.utils.data.DataLoader(ttt, batch_size=16, shuffle=False, num_workers=5, collate_fn=forms_detect.collate)
    #valid_data_loader = data_loader.split_validation()

    if checkpoint is not None:
        if 'swa_state_dict' in checkpoint and checkpoint['iteration']>config['trainer']['swa_start']:
            model = eval(config['arch'])(config['model'])
            if 'style' in config['model'] and 'lookup' in config['model']['style']:
                model.style_extractor.add_authors(data_loader.dataset.authors) ##HERE
            #just strip off the 'module.' tag. I DON'T KNOW IF THIS WILL WORK PROPERLY WITH BATCHNORM
            new_state_dict = {key[7:]:value for key,value in checkpoint['swa_state_dict'].items() if key.startswith('module.')}
            model.load_state_dict(new_state_dict)
            print('Successfully loaded SWA model')
        elif 'state_dict' in checkpoint:
            model = eval(config['arch'])(config['model'])
            if 'style' in config['model'] and 'lookup' in config['model']['style']:
                model.style_extractor.add_authors(data_loader.dataset.authors) ##HERE
            model.load_state_dict(checkpoint['state_dict'])
        elif 'swa_model' in checkpoint:
            model = checkpoint['swa_model']
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])

    if use_train_model:
        model.train()
    else:
        model.eval()
    if verbosity>1:
        model.summary()
    else:
        try:
            print('model param counts: {}'.format(model.num_params()))
        except torch.nn.modules.module.ModuleAttributeError:
            pass

    if type(config['loss'])==dict: 
        loss={}#[eval(l) for l in config['loss']]
        for name,l in config['loss'].items():
            loss[name]=eval(l)
    else:   
        loss = eval(config['loss'])
    metrics = [eval(metric) for metric in config['metrics']]


    train_logger = Logger()
    trainerClass = eval(config['trainer']['class'])
    trainer = trainerClass(model, loss, metrics,
                      resume=False, #path
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      train_logger=train_logger)
    trainer.save_images_every=-1
    #saveFunc = eval(trainer_class+'_printer')
    saveFunc = eval(config['data_loader']['data_set_name']+'_eval')

    do_saliency_map = config['saliency'] if 'saliency' in config else False
    do_graph_check_map = config['graph_check'] if 'graph_check' in config else False
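    # optional wrappers for saliency maps (SimpleFullGradMod) and graph
    # checking (GraphChecker) around the model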
    if do_saliency_map:
        trainer.saliency_model = SimpleFullGradMod(trainer.model)
    if do_graph_check_map:
        trainer.graph_check_model = GraphChecker(trainer.model)

    step=5

    #numberOfImages = numberOfImages//config['data_loader']['batch_size']
    #print(len(data_loader))
    if data_loader is not None:
        train_iter = iter(data_loader)
    valid_iter = iter(valid_data_loader)

    #print("WARNING GRAD ENABLED")
    with torch.no_grad():
        if index is None:


            if saveDir is not None:
                trainDir = os.path.join(saveDir,'train_'+config['name'])
                validDir = os.path.join(saveDir,'valid_'+config['name'])
                if not os.path.isdir(trainDir):
                    os.mkdir(trainDir)
                if not os.path.isdir(validDir):
                    os.mkdir(validDir)
            else:
                trainDir=None
                validDir=None

            val_metrics_sum = np.zeros(len(metrics))
            val_metrics_list = defaultdict(lambda: defaultdict(list))
            val_comb_metrics = defaultdict(list)

            #if numberOfImages==0:
            #    for i in range(len(valid_data_loader)):
            #        print('valid batch index: {}\{} (not save)'.format(i,len(valid_data_loader)),end='\r')
            #        instance=valid_iter.next()
            #        metricsO,_ = saveFunc(config,instance,model,gpu,metrics)

            #        if type(metricsO) == dict:
            #            for typ,typeLists in metricsO.items():
            #                if type(typeLists) == dict:
            #                    for name,lst in typeLists.items():
            #                        val_metrics_list[typ][name]+=lst
            #                        val_comb_metrics[typ]+=lst
            #                else:
            #                    if type(typeLists) is float or type(typeLists) is int:
            #                        typeLists = [typeLists]
            #                    val_comb_metrics[typ]+=typeLists
            #        else:
            #            val_metrics_sum += metricsO.sum(axis=0)/metricsO.shape[0]
            #else:

            ####
            if 'save_spaced' in config:
                spaced={}
                spaced_val={}
                assert(config['data_loader']['batch_size']==1)
                assert(config['validation']['batch_size']==1)
                if 'a_batch_size' in config['data_loader']:
                    assert(config['data_loader']['a_batch_size']==1)
                if 'a_batch_size' in config['validation']:
                    assert(config['validation']['a_batch_size']==1)
            if 'save_nns' in config:
                nns=[]
            if 'save_style' in config:
                if toEval is None:
                    toEval =[]
                if 'style' not in toEval:
                    toEval.append('style')
                if 'author' not in toEval:
                    toEval.append('author')
                styles=[]
                authors=[]
                strings=[]
                stylesVal=[]
                authorsVal=[]
                spacedVal=[]
                stringsVal=[]
                
                doIds = config['data_loader']['data_set_name']=='StyleWordDataset'
                #doSpaced = not doIds#?
                doSpaced = 'doSpaced' in config
                if doSpaced:
                    if 'spaced_label' not in toEval:
                        toEval.append('spaced_label')
                    if 'gt' not in toEval:
                        toEval.append('gt')
                ids=[]
                idsVal=[]
                saveStyleEvery=config['saveStyleEvery'] if 'saveStyleEvery' in config else 5000
                saveStyleLoc = config['save_style']
                lastSlash = saveStyleLoc.rfind('/')
                if lastSlash>=0:
                    saveStyleValLoc = saveStyleLoc[:lastSlash+1]+'val_'+saveStyleLoc[lastSlash+1:]
                else:
                    saveStyleValLoc = 'val_'+saveStyleLoc

            validName='valid' if not test else 'test'

            startBatch = config['startBatch'] if 'startBatch' in config else 0
            numberOfBatches = numberOfImages//batchSize
            if numberOfBatches==0 and numberOfImages>1:
                numberOfBatches = 1

            #for index in range(startIndex,numberOfImages,step*batchSize):
            batch = startBatch
            for batch in range(startBatch,numberOfBatches):
            
                #for validIndex in range(index,index+step*vBatchSize, vBatchSize):
                #for validBatch
                    #if valyypidIndex/vBatchSize < len(valid_data_loader):
                if batch < len(valid_data_loader) and not do_train:
                    if verbosity>0:
                        print('{} batch index: {}/{}       '.format(validName,batch,len(valid_data_loader)),end='\r')
                    #data, target = next(valid_iter) #valid_data_loader[validIndex]
                    #dataT  = _to_tensor(gpu,data)
                    #output = model(dataT)
                    #data = data.cpu().data.numpy()
                    #output = output.cpu().data.numpy()
                    #target = target.data.numpy()
                    #metricsO = _eval_metrics_ind(metrics,output, target)
                    metricsO,aux = saveFunc(config,next(valid_iter),trainer,metrics,validDir,batch*vBatchSize,toEval=toEval)
                    if type(metricsO) == dict:
                        for typ,typeLists in metricsO.items():
                            if type(typeLists) == dict:
                                for name,lst in typeLists.items():
                                    val_metrics_list[typ][name]+=lst
                                    val_comb_metrics[typ]+=lst
                            else:
                                if type(typeLists) is float or type(typeLists) is int:
                                    typeLists = [typeLists]
                                if type(typeLists) is np.ndarray:
                                    val_comb_metrics[typ].append(typeLists)
                                else:
                                    val_comb_metrics[typ]+=typeLists
                    else:
                        val_metrics_sum += metricsO.sum(axis=0)/metricsO.shape[0]
                    if 'save_spaced' in config:
                        spaced_val[aux['name'][0]] = aux['spaced_label'].cpu().numpy()
                    if 'save_style' in config:
                        stylesVal.append(aux['style'])
                        authorsVal+=aux['author']
                        if doIds:
                            idsVal+=aux['name']
                        elif doSpaced:
                            #spacedVal.append(aux[2])
                            spacedVal+=aux['spaced_label']
                            stringsVal+=aux['gt']
                        if batch>0 and batch%saveStyleEvery==0:
                            save_style(saveStyleValLoc,batch,stylesVal,authorsVal,idsVal,doIds, spacedVal,stringsVal, doSpaced)
                            stylesVal=[]
                            authorsVal=[]
                            idsVal=[]
                            spacedVal=[]
                            stringsVal=[]

                if not test and do_train:
                    #for trainIndex in range(index,index+step*batchSize, batchSize):
                    #    if trainIndex/batchSize < len(data_loader):
                    if batch < len(data_loader):
                        if verbosity>0:
                            print('train batch index: {}/{}        '.format(batch,len(data_loader)),end='\r')
                        #data, target = next(train_iter) #data_loader[trainIndex]
                        #dataT = _to_tensor(gpu,data)
                        #output = model(dataT)
                        #data = data.cpu().data.numpy()
                        #output = output.cpu().data.numpy()
                        #target = target.data.numpy()
                        #metricsO = _eval_metrics_ind(metrics,output, target)
                        _,aux=saveFunc(config,next(train_iter),trainer,metrics,trainDir,batch*batchSize,toEval=toEval)
                        if 'save_nns' in config:
                            nns+=aux[-1]
                        if 'save_spaced' in config:
                            spaced[aux['name'][0]] = aux['spaced_label'].cpu().numpy()
                        if 'save_style' in config:
                            styles.append(aux['style'])
                            authors+=aux['author']
                            if doIds:
                                ids+=aux['name']
                            elif doSpaced:
                                #spaced.append(aux[2])
                                spaced+=aux['spaced_label']
                                strings+=aux['gt']
                            if batch>0 and batch%saveStyleEvery==0:
                                save_style(saveStyleLoc,batch,styles,authors,ids,doIds,spaced,strings,doSpaced)
                                styles=[]
                                authors=[]
                                ids=[]
                                spaced=[]
                                strings=[]

            #if gpu is not None or numberOfImages==0:
            try:
                for vi in range(batch,len(valid_data_loader)):
                    if verbosity>0:
                    print('{} batch index: {}/{} (not save)   '.format(validName,vi,len(valid_data_loader)),end='\r')
                    instance = next(valid_iter)
                    metricsO,aux = saveFunc(config,instance,trainer,metrics,toEval=toEval)
                    if type(metricsO) == dict:
                        for typ,typeLists in metricsO.items():
                            if type(typeLists) == dict:
                                for name,lst in typeLists.items():
                                    val_metrics_list[typ][name]+=lst
                                    val_comb_metrics[typ]+=lst
                            elif typeLists is not None:
                                if type(typeLists) is float or type(typeLists) is int:
                                    typeLists = [typeLists]
                                if type(typeLists) is np.ndarray:
                                    val_comb_metrics[typ].append(typeLists)
                                else:
                                    val_comb_metrics[typ]+=typeLists
                    else:
                        val_metrics_sum += metricsO.sum(axis=0)/metricsO.shape[0]
                    if 'save_spaced' in config:
                        spaced_val[aux['name'][0]] = aux['spaced_label'].cpu().numpy()
                    if 'save_style' in config:
                        stylesVal.append(aux['style'])
                        authorsVal+=aux['author']
                        if doIds:
                            idsVal+=aux['name']
                        elif doSpaced:
                            #spacedVal.append(aux[2])
                            spacedVal+=aux['spaced_label']
                            stringsVal+=aux['gt']
                        if vi>0 and vi%saveStyleEvery==0:
                            save_style(saveStyleValLoc,vi,stylesVal,authorsVal,idsVal,doIds,spacedVal,stringsVal,doSpaced)
                            stylesVal=[]
                            authorsVal=[]
                            idsVal=[]
                            spacedVal=[]
                            stringsVal=[]
            except StopIteration:
                print('ERROR: ran out of valid batches early. Expected {} more'.format(len(valid_data_loader)-vi))
            ####

            with warnings.catch_warnings():   
                warnings.simplefilter('error')
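                # promote warnings (e.g. numpy's "mean of empty slice") to exceptions
                # so they surface in the try/except reporting below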
                val_metrics_sum /= len(valid_data_loader)
                BROS_prec=None
                rel_BROS_TP=None
                group_TP = None
                print('{} metrics'.format(validName))
                for i in range(len(metrics)):
                    print(metrics[i].__name__ + ': '+str(val_metrics_sum[i]))
                for typ in val_comb_metrics:
                    if 'final_rel_XX_predCount'==typ:
                        rel_pred_count = sum(val_comb_metrics[typ])
                    elif 'final_rel_XX_gtCount'==typ:
                        rel_gt_count = sum(val_comb_metrics[typ])
                    elif 'final_rel_XX_strict_TP'==typ:
                        rel_strict_TP = sum(val_comb_metrics[typ])
                    elif 'final_rel_XX_BROS_TP'==typ:
                        rel_BROS_TP = sum(val_comb_metrics[typ])
                    elif 'final_group_XX_TP'==typ:
                        group_TP = sum(val_comb_metrics[typ])
                    elif 'final_group_XX_gtCount'==typ:
                        group_gt_count = sum(val_comb_metrics[typ])
                    elif 'final_group_XX_predCount'==typ:
                        group_pred_count = sum(val_comb_metrics[typ])
                    elif 'ED_TP_XX'==typ:
                        group_TP=sum(val_comb_metrics[typ])
                    elif 'ED_true_count_XX'==typ:
                        group_gt_count=sum(val_comb_metrics[typ])
                    elif 'ED_pred_count_XX'==typ:
                        group_pred_count=sum(val_comb_metrics[typ])

                    else:
                        assert 'XX' not in typ
                        if 'final_rel_BROS_prec'==typ:
                            BROS_prec = np.mean(val_comb_metrics[typ],axis=0)
                        elif 'final_rel_BROS_recall'==typ:
                            BROS_recall = np.mean(val_comb_metrics[typ],axis=0)
                        elif 'final_rel_BROS_Fm'==typ:
                            BROS_Fm = np.mean(val_comb_metrics[typ],axis=0)
                        try:
                            print('{} overall mean: {}, std {}'.format(typ,np.mean(val_comb_metrics[typ],axis=0), np.std(val_comb_metrics[typ],axis=0)))
                            for name, typeLists in val_metrics_list[typ].items():
                                print('{} {} mean: {}, std {}'.format(typ,name,np.mean(typeLists,axis=0),np.std(typeLists,axis=0)))
                        except Exception as e:
                            print('ERROR on {}: {}'.format(typ,e))
                            print('{}'.format(val_comb_metrics[typ]))

                if BROS_prec is not None:
                    print('----PER DOCUMENT------')
                    print('BROS relationship Recall Prec F1: {:.2f} , {:.2f} , {:.2f}'.format(100*BROS_recall,100*BROS_prec,100*BROS_Fm))
                if rel_BROS_TP is not None:
                    print('----OVERALL------')
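                    # micro-averaged over the whole set: recall = TP/gtCount, precision = TP/predCount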
                    BROS_recall = rel_BROS_TP/rel_gt_count
                    BROS_prec = rel_BROS_TP/rel_pred_count
                    print('BROS relationships Recall Prec F1: {:.2f} , {:.2f} , {:.2f}'.format(100*BROS_recall,100*BROS_prec,100*2*BROS_recall*BROS_prec/(BROS_prec+BROS_recall)))
                    #strict_recall = rel_strict_TP/rel_gt_count
                    #strict_prec = rel_strict_TP/rel_pred_count
                    #print('strict relationships Recall Prec F1: {:.2f} , {:.2f} , {:.2f}'.format(100*strict_recall,100*strict_prec,100*2*strict_recall*strict_prec/(strict_prec+strict_recall)))
                if group_TP is not None:
                    group_recall = group_TP/group_gt_count
                    group_prec = group_TP/group_pred_count
                    print('entity Recall Prec F1: {:.2f} , {:.2f} , {:.2f}'.format(100*group_recall,100*group_prec,100*2*group_recall*group_prec/(group_prec+group_recall)))

            if 'save_nns' in config:
                pickle.dump(nns,open(config['save_nns'],'wb'))
            if 'save_spaced' in config:
                #import pdb;pdb.set_trace()
                #spaced = torch.cat(spaced,dim=1).numpy()
                #spaced_val = torch.cat(spaced_val,dim=1).numpy()
                saveSpacedLoc = config['save_spaced']
                lastSlash = saveSpacedLoc.rfind('/')
                if lastSlash>=0:
                    saveSpacedValLoc = saveSpacedLoc[:lastSlash+1]+'val_'+saveSpacedLoc[lastSlash+1:]
                else:
                    saveSpacedValLoc = 'val_'+saveSpacedLoc
                with open(saveSpacedLoc,'wb') as f:
                    pickle.dump(spaced,f)

                with open(saveSpacedValLoc,'wb') as f:
                    pickle.dump(spaced_val,f)

            if 'save_style' in config:
                if len(styles)>0:
                    save_style(saveStyleLoc,len(data_loader),styles,authors,ids,doIds)
                if len(stylesVal)>0:
                    save_style(saveStyleValLoc,len(valid_data_loader),stylesVal,authorsVal,idsVal,doIds)
        elif type(index)==int:
            if index>0:
                instances = train_iter
            else:
                index*=-1
                instances = valid_iter
            batchIndex = index//batchSize
            inBatchIndex = index%batchSize
            for i in range(batchIndex+1):
                instance = next(instances)
            #data, target = data[inBatchIndex:inBatchIndex+1], target[inBatchIndex:inBatchIndex+1]
            #dataT = _to_tensor(gpu,data)
            #output = model(dataT)
            #data = data.cpu().data.numpy()
            #output = output.cpu().data.numpy()
            #target = target.data.numpy()
            #print (output.shape)
            #print ((output.min(), output.amin()))
            #print (target.shape)
            #print ((target.amin(), target.amin()))
            #metricsO = _eval_metrics_ind(metrics,output, target)
            saveFunc(config,instance,model,gpu,metrics,saveDir,batchIndex*batchSize,toEval=toEval)
        else:
            for instance in data_loader:
                if index in instance['imgName']:
                    break
            if index not in instance['imgName']:
                for instance in valid_data_loader:
                    if index in instance['imgName']:
                        break
            if index in instance['imgName']:
                saveFunc(config,instance,model,gpu,metrics,saveDir,0,toEval=toEval)
            else:
                print('{} not found! (on {})'.format(index,instance['imgName']))
    try:
        do = trainer.do_characterization
    except AttributeError:
        do = False
    if do:
        trainer.displayCharacterization()
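The relationship and entity scores printed by the evaluation above are micro-averaged: true-positive, prediction, and ground-truth counts are summed over the whole validation set before precision, recall, and F1 are derived. A minimal sketch of that arithmetic (the helper name is illustrative, not from the source):

def micro_prf(tp, pred_count, gt_count):
    # micro-averaged scores from counts summed across all documents
    recall = tp / gt_count
    precision = tp / pred_count
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# e.g. 90 correct out of 100 predicted, against 120 ground-truth relationships
print(micro_prf(90, 100, 120))  # -> (0.9, 0.75, 0.818...)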
Example No. 12
0
    def __init__(self, model, loss, metrics, resume, config,
                 data_loader, valid_data_loader=None, train_logger=None):
        super(QRGrowGenTrainer, self).__init__(model, loss, metrics, resume, config, train_logger)
        assert self.curriculum
        self.config = config
        if 'loss_params' in config:
            self.loss_params=config['loss_params']
        else:
            self.loss_params={}
        for lossname in self.loss:
            if lossname not in self.loss_params:
                self.loss_params[lossname]={}
        self.lossWeights = config['loss_weights'] if 'loss_weights' in config else {"auto": 1, "recog": 1}
        if data_loader is not None:
            self.batch_size = data_loader.batch_size
            self.data_loader = data_loader
            if hasattr(self.data_loader.dataset, 'refresh_data'):
                self.data_loader.dataset.refresh_data(None,None,self.logged)
            self.data_loader_iter = iter(data_loader)
        if self.val_step<0:
            self.valid_data_loader=None
            print('Set valid_data_loader to None')
        else:
            self.valid_data_loader = valid_data_loader
        self.valid = True if self.valid_data_loader is not None else False



        self.feature_loss = 'feature' in self.loss
        if self.feature_loss:
            self.model.something.setup_save_features()

        self.to_display={}



        self.gan_loss = 'discriminator' in config['model']
        self.disc_iters = config['trainer']['disc_iters'] if 'disc_iters' in config['trainer'] else 1

        #This text data could be used to randomly sample strings, if we so choose
        text_data_batch_size = config['trainer']['text_data_batch_size'] if 'text_data_batch_size' in config['trainer'] else self.config['data_loader']['batch_size']
        text_words = config['trainer']['text_words'] if 'text_words' in config['trainer'] else False
        if 'a_batch_size' in self.config['data_loader']:
            self.a_batch_size = self.config['data_loader']['a_batch_size']
            text_data_batch_size*=self.config['data_loader']['a_batch_size']
        else:
            self.a_batch_size=1
        #text_data_max_len = config['trainer']['text_data_max_len'] if 'text_data_max_len' in config['trainer'] else 20
        if data_loader is not None:
            if 'text_data' in config['trainer']:
                text_data_max_len = self.data_loader.dataset.max_len()
                characterBalance = config['trainer']['character_balance'] if 'character_balance' in config['trainer'] else False
                text_data_max_len = config['trainer']['text_data_max_len'] if 'text_data_max_len' in config['trainer'] else text_data_max_len
                self.text_data = TextData(config['trainer']['text_data'],config['data_loader']['char_file'],text_data_batch_size,max_len=text_data_max_len,words=text_words,characterBalance=characterBalance)
            else:
                self.text_data = None
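        # NOTE: the training loop presumably samples random strings from
        # self.text_data when generating synthetic images; that call site is
        # outside this excerpt.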

        self.balance_loss = config['trainer']['balance_loss'] if 'balance_loss' in config['trainer'] else False # balance the CTC loss against the others as in https://arxiv.org/pdf/1903.00277.pdf, although with many of my own variations (which work better)
        if self.balance_loss:
            self.parameters = list(model.parameters())
            self.balance_var_x = config['trainer']['balance_var_x'] if 'balance_var_x' in config['trainer'] else None
            if self.balance_loss.startswith('sign_preserve_x'):
                self.balance_x = float(self.balance_loss[self.balance_loss.find('x')+1:])
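                # e.g. a balance_loss of 'sign_preserve_x0.5' yields balance_x = 0.5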
            self.saved_grads = [] #this will hold the gradients for previous training steps if "no-step" is specified





        if 'align_network' in config['trainer']:
            self.align_network = JoinNet()
            weights = config['trainer']['align_network']
            state_dict = torch.load(weights, map_location=lambda storage, location: storage)
            self.align_network.load_state_dict(state_dict)
            self.align_network.set_requires_grad(False)

        self.no_bg_loss = config['trainer']['no_bg_loss'] if 'no_bg_loss' in config['trainer'] else False
        
        self.sample_disc = self.curriculum.sample_disc if self.curriculum is not None else False
        #if we are going to sample images from the past for the discriminator, these are to store previous generations
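        #(new_gen keeps up to store_new_gen_limit recent fake batches; old_gen keeps up
        # to store_old_gen_limit and may be cached on disk, so a resumed run re-discovers
        # previously saved .pt files in old_gen_cache, as the loop below does)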
        if self.sample_disc:
            self.new_gen=[]
            self.old_gen=[]
            self.store_new_gen_limit = 10
            self.store_old_gen_limit = config['trainer']['store_old_gen_limit'] if 'store_old_gen_limit' in config['trainer'] else 200
            self.new_gen_freq = config['trainer']['new_gen_freq'] if 'new_gen_freq' in config['trainer'] else 0.7
            self.forget_new_freq = config['trainer']['forget_new_freq'] if 'forget_new_freq'  in config['trainer'] else 0.0
            self.old_gen_cache = config['trainer']['old_gen_cache'] if 'old_gen_cache' in config['trainer'] else os.path.join(self.checkpoint_dir,'old_gen_cache')
            if self.old_gen_cache is not None:
                util.ensure_dir(self.old_gen_cache)
                #check for files in cache, so we can resume with them
                for i in range(self.store_old_gen_limit):
                    path = os.path.join(self.old_gen_cache,'{}.pt'.format(i))
                    if os.path.exists(path):
                        self.old_gen.append(path)
                    else:
                        break


        self.WGAN = config['trainer']['WGAN'] if 'WGAN' in config['trainer'] else False
        self.DCGAN = config['trainer']['DCGAN'] if 'DCGAN' in config['trainer'] else False
        if self.DCGAN:
            self.criterion = torch.nn.BCELoss()
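        # (DCGAN training uses binary cross-entropy against real/fake labels;
        #  the WGAN flag presumably selects a Wasserstein critic objective elsewhere)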

        #if 'encoder_weights' in config['trainer']:
        #    snapshot = torch.load(config['trainer']['encoder_weights'],map_location='cpu')
        #    encoder_state_dict={}
        #    for key,value in  snapshot['state_dict'].items():
        #        if key[:8]=='encoder.':
        #            encoder_state_dict[key[8:]] = value
        #    if 'encoder_type' not in config['trainer'] or config['trainer']['encoder_type']=='normal':
        #        self.encoder = Encoder()
        #    elif config['trainer']['encoder_type']=='small':
        #        self.encoder = EncoderSm()
        #    elif config['trainer']['encoder_type']=='2':
        #        self.encoder = Encoder2()
        #    elif config['trainer']['encoder_type']=='2tight':
        #        self.encoder = Encoder2(32)
        #    elif config['trainer']['encoder_type']=='2tighter':
        #        self.encoder = Encoder2(16)
        #    elif config['trainer']['encoder_type']=='3':
        #        self.encoder = Encoder3()
        #    elif config['trainer']['encoder_type']=='32':
        #        self.encoder = Encoder32(256)
        #    else:
        #        raise NotImplementedError('Unknown encoder type: {}'.format(config['trainer']['encoder_type']))
        #    self.encoder.load_state_dict( encoder_state_dict )
        #    if self.with_cuda:
        #        self.encoder = self.encoder.to(self.gpu)

        #This is for saving results during training!
        self.print_dir = config['trainer']['print_dir'] if 'print_dir' in config['trainer'] else None
        if self.print_dir is not None:
            util.ensure_dir(self.print_dir)
        self.print_every = config['trainer']['print_every'] if 'print_every' in config['trainer'] else 100
        self.iter_to_print = self.print_every
        self.serperate_print_every = config['trainer']['serperate_print_every'] if 'serperate_print_every' in config['trainer'] else 2500
        self.last_print_images=defaultdict(int)
        self.print_next_gen=False
        self.print_next_auto=False



        if 'alt_data_loader' in config:
            alt_config={'data_loader': config['alt_data_loader'],'validation':{}}
            self.alt_data_loader, alt_valid_data_loader = getDataLoader(alt_config,'train')
            self.alt_data_loader_iter = iter(self.alt_data_loader)
        if 'triplet_data_loader' in config:
            triplet_config={'data_loader': config['triplet_data_loader'],'validation':{}}
            self.triplet_data_loader, triplet_valid_data_loader = getDataLoader(triplet_config,'train')
            self.triplet_data_loader_iter = iter(self.triplet_data_loader)


        #self.optimizer.add_param_group(
        #    {
        #        'params': model.generator.style_emb.parameters(),
        #        'lr': self.lr_schedule * 0.01,
        #        'mult': 0.01,
        #    }
        #) this is done in base trainer
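        # progressive-growing bookkeeping: resolution = 4 * 2**step, so
        # init_size=8 corresponds to init_step=1 and max_size=256 to max_step=6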
        self.phase = config['trainer']['grow_step_length']
        self.init_size=8
        self.init_step = int(math.log2(self.init_size)) - 2
        self.resolution = 4 * 2 ** self.init_step
        self.lr_schedule = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        adjust_lr(self.optimizer, self.lr_schedule.get(self.resolution, 0.001))
        adjust_lr(self.optimizer_discriminator, self.lr_schedule.get(self.resolution, 0.001))
        max_size=256
        self.max_step = int(math.log2(max_size)) - 2
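
For reference, a minimal, self-contained sketch of the resolution and learning-rate arithmetic this grow schedule implies; the schedule values and the 0.001 fallback mirror the code above, while the walk over steps is illustrative:

import math

# learning-rate schedule copied from the trainer; resolutions not listed
# fall back to 0.001 (the .get(resolution, 0.001) default above)
LR_SCHEDULE = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}

def resolution_for_step(step):
    # the trainer encodes resolution as 4 * 2**step: step 1 -> 8px, step 6 -> 256px
    return 4 * 2 ** step

init_step = int(math.log2(8)) - 2     # init_size = 8   -> step 1
max_step = int(math.log2(256)) - 2    # max_size  = 256 -> step 6
for step in range(init_step, max_step + 1):
    res = resolution_for_step(step)
    print('step {}: {}px, lr {}'.format(step, res, LR_SCHEDULE.get(res, 0.001)))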