Example #1
def get_model(args):
    backbone = ResNetV2(layers=args.backbone_layers,
                        num_classes=0,
                        global_pool='',
                        in_chans=args.channels,
                        preact=False,
                        stem_type='same',
                        conv_layer=StdConv2dSame)
    encoder = CustomVisionTransformer(img_size=(args.max_height,
                                                args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      hybrid_backbone=backbone).to(args.device)

    decoder = CustomARWrapper(TransformerWrapper(num_tokens=args.num_tokens,
                                                 max_seq_len=args.max_seq_len,
                                                 attn_layers=Decoder(
                                                     dim=args.dim,
                                                     depth=args.num_layers,
                                                     heads=args.heads,
                                                     **args.decoder_args)),
                              pad_value=args.pad_token).to(args.device)
    if 'wandb' in args and args.wandb:
        import wandb
        wandb.watch((encoder, decoder.net.attn_layers))
    return Model(encoder, decoder, args)
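For orientation, here is a minimal sketch of how this factory might be invoked. The field names mirror the attributes get_model reads; the concrete values are illustrative assumptions only, and the project-internal classes (CustomVisionTransformer, CustomARWrapper, Model) are assumed to already be in scope.

from munch import Munch

# Illustrative argument set; every field below is accessed somewhere in get_model.
args = Munch(backbone_layers=[2, 3, 7], channels=1,
             max_height=192, max_width=672, patch_size=16,
             dim=256, encoder_depth=4, heads=8,
             num_tokens=8000, max_seq_len=512, num_layers=4,
             decoder_args={}, pad_token=0,
             device='cpu', wandb=False)
model = get_model(args)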
Example #2
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9),
        num_classes=0,
        global_pool="",
        in_chans=kwargs.get("in_chans", 3),
        preact=False,
        stem_type="same",
        conv_layer=StdConv2dSame,
    )
    model_kwargs = dict(
        embed_dim=768,
        depth=12,
        num_heads=12,
        hybrid_backbone=backbone,
        representation_size=768,
        **kwargs,
    )
    model = _create_vision_transformer(
        "vit_base_resnet50_224_in21k", pretrained=pretrained, **model_kwargs
    )
    return model
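If this builder is registered with timm's @register_model decorator (not shown in the excerpt), the same model can be instantiated by name and sanity-checked on a dummy batch. A sketch, assuming a timm version that still registers this model name:

import timm
import torch

m = timm.create_model('vit_base_resnet50_224_in21k', pretrained=False)
x = torch.randn(1, 3, 224, 224)  # one dummy 224x224 RGB image
logits = m(x)                    # class logits over the ImageNet-21k label set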
Example #3
def get_model(args, training=False):
    backbone = ResNetV2(layers=args.backbone_layers,
                        num_classes=0,
                        global_pool='',
                        in_chans=args.channels,
                        preact=False,
                        stem_type='same',
                        conv_layer=StdConv2dSame)
    min_patch_size = 2**(len(args.backbone_layers) + 1)

    def embed_layer(**x):
        ps = x.pop('patch_size', min_patch_size)
        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
        return HybridEmbed(**x,
                           patch_size=ps // min_patch_size,
                           backbone=backbone)

    encoder = CustomVisionTransformer(img_size=(args.max_height,
                                                args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      embed_layer=embed_layer).to(args.device)

    decoder = CustomARWrapper(TransformerWrapper(num_tokens=args.num_tokens,
                                                 max_seq_len=args.max_seq_len,
                                                 attn_layers=Decoder(
                                                     dim=args.dim,
                                                     depth=args.num_layers,
                                                     heads=args.heads,
                                                     **args.decoder_args)),
                              pad_value=args.pad_token).to(args.device)
    if 'wandb' in args and args.wandb:
        import wandb
        wandb.watch((encoder, decoder.net.attn_layers))
    model = Model(encoder, decoder, args)
    if training:
        # check if largest batch can be handled by system
        im = torch.empty(args.batchsize,
                         args.channels,
                         args.max_height,
                         args.max_width,
                         device=args.device).float()
        seq = torch.randint(0,
                            args.num_tokens,
                            (args.batchsize, args.max_seq_len),
                            device=args.device).long()
        decoder(seq, context=encoder(im)).sum().backward()
        model.zero_grad()
        torch.cuda.empty_cache()
        del im, seq
    return model
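A note on the min_patch_size arithmetic above: timm's ResNetV2 stem downsamples by 4 and each stage after the first by another 2, so an n-stage backbone reduces spatial resolution by 4 * 2**(n-1) = 2**(n+1). The assertion in embed_layer therefore rejects any patch size the backbone's output grid cannot realize. A worked check, assuming a three-stage configuration:

backbone_layers = [2, 3, 7]                       # three ResNet stages (illustrative)
min_patch_size = 2 ** (len(backbone_layers) + 1)  # stem x4, then x2 per extra stage
assert min_patch_size == 16
# patch_size=16 or 32 would pass the embed_layer assertion; patch_size=8 would fail.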
Example #4
def initialize(arguments=None):
    if arguments is None:
        arguments = Munch({
            'config': 'settings/config.yaml',
            'checkpoint': 'checkpoints/weights.pth',
            'no_cuda': True,
            'no_resize': False
        })
    logging.getLogger().setLevel(logging.FATAL)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    with open(arguments.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params))
    args.update(**vars(arguments))
    args.wandb = False
    args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'

    model = get_model(args)
    model.load_state_dict(torch.load(args.checkpoint,
                                     map_location=args.device))

    if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
        image_resizer = ResNetV2(layers=[2, 3, 3],
                                 num_classes=max(args.max_dimensions) // 32,
                                 global_pool='avg',
                                 in_chans=1,
                                 drop_rate=.05,
                                 preact=True,
                                 stem_type='same',
                                 conv_layer=StdConv2dSame).to(args.device)
        image_resizer.load_state_dict(
            torch.load(os.path.join(os.path.dirname(args.checkpoint),
                                    'image_resizer.pth'),
                       map_location=args.device))
        image_resizer.eval()
    else:
        image_resizer = None
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
    return args, model, image_resizer, tokenizer
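A usage sketch for this entry point. The paths are the same defaults the function itself falls back to; setting no_cuda=False allows GPU inference when available:

from munch import Munch

arguments = Munch(config='settings/config.yaml',
                  checkpoint='checkpoints/weights.pth',
                  no_cuda=False,   # allow CUDA if available
                  no_resize=False)
args, model, image_resizer, tokenizer = initialize(arguments)
model.eval()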
Example #5
def initialize(arguments):
    filename = join(dirname(__file__), arguments.config)
    with open(filename, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = Munch(params)
    args.update(**vars(arguments))
    args.wandb = False
    args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'

    model = get_model(args)
    model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))

    if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
        # num_classes is hardcoded to 22 here; Example #4 computes this value as max(args.max_dimensions) // 32 instead
        image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=22, global_pool='avg', in_chans=1, drop_rate=.05,
                                 preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
        image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'), map_location=args.device))
        image_resizer.eval()
    else:
        image_resizer = None
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
    return args, model, image_resizer, tokenizer
def main(args):
    # data
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(batchsize=args.batchsize,
                      test=False,
                      max_dimensions=args.max_dimensions,
                      keep_smaller_batches=True,
                      device=args.device)
    valloader = Im2LatexDataset().load(args.valdata)
    valloader.update(batchsize=args.batchsize,
                     test=True,
                     max_dimensions=args.max_dimensions,
                     keep_smaller_batches=True,
                     device=args.device)

    # model
    model = ResNetV2(layers=[2, 3, 3],
                     num_classes=int(max(args.max_dimensions) // 32),
                     global_pool='avg',
                     in_chans=args.channels,
                     drop_rate=.05,
                     preact=True,
                     stem_type='same',
                     conv_layer=StdConv2dSame).to(args.device)
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
    opt = Adam(model.parameters(), lr=args.lr)
    crit = nn.CrossEntropyLoss()
    sched = OneCycleLR(opt,
                       max_lr=.005,  # the scheduler manages the lr itself, overriding the optimizer's initial args.lr
                       total_steps=args.num_epochs * len(dataloader))
    global bestacc
    bestacc = val(valloader, model, args.valbatches, args.device)

    def train_epoch(sched=None):
        iter(dataloader)  # resets the dataset's internal iteration state
        dset = tqdm(range(len(dataloader)))
        for i in dset:
            im, label = prepare_data(dataloader)
            if im is not None:
                if im.shape[-1] > dataloader.max_dimensions[0] or im.shape[-2] > dataloader.max_dimensions[1]:
                    continue
                opt.zero_grad()
                label = label.to(args.device)

                pred = model(im.to(args.device))
                loss = crit(pred, label)
                if i % 2 == 0:
                    dset.set_description('Loss: %.4f' % loss.item())
                loss.backward()
                opt.step()
                if sched is not None:
                    sched.step()
            if (i + 1) % args.sample_freq == 0 or i + 1 == len(dset):
                acc = val(valloader, model, args.valbatches, args.device)
                print('Accuracy %.2f' % (100 * acc), '%')
                global bestacc
                if acc > bestacc:
                    torch.save(model.state_dict(), args.out)
                    bestacc = acc

    for _ in range(args.num_epochs):
        train_epoch(sched)
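Finally, a sketch of a driver for this training script. The field names mirror what main reads; every value below is an illustrative assumption, not a project default:

from munch import Munch
import torch

args = Munch(data='dataset/train.pkl', valdata='dataset/val.pkl',
             batchsize=10, max_dimensions=(1024, 512), channels=1,
             resume='', lr=5e-4, num_epochs=10, valbatches=100,
             sample_freq=1000, out='checkpoints/image_resizer.pth',
             device='cuda' if torch.cuda.is_available() else 'cpu')
main(args)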