def get_model(args):
    backbone = ResNetV2(
        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      hybrid_backbone=backbone).to(args.device)
    decoder = CustomARWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                **args.decoder_args)),
        pad_value=args.pad_token).to(args.device)
    if 'wandb' in args and args.wandb:
        import wandb
        wandb.watch((encoder, decoder.net.attn_layers))
    return Model(encoder, decoder, args)
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool="", in_chans=kwargs.get("in_chans", 3),
        preact=False, stem_type="same", conv_layer=StdConv2dSame)
    model_kwargs = dict(
        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
        representation_size=768, **kwargs)
    model = _create_vision_transformer(
        "vit_base_resnet50_224_in21k", pretrained=pretrained, **model_kwargs)
    return model
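# Usage sketch (not part of the original snippet): with a timm release that registers
# this model name, the same hybrid can be created through the factory instead of calling
# the constructor above directly; num_classes=0 keeps the headless configuration.
import timm

hybrid_vit = timm.create_model('vit_base_resnet50_224_in21k', pretrained=False, num_classes=0)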
def get_model(args, training=False):
    backbone = ResNetV2(
        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    # the CNN backbone already downsamples the input, so the effective ViT patch size
    # must be a multiple of 2**(number of stages + 1)
    min_patch_size = 2**(len(args.backbone_layers) + 1)

    def embed_layer(**x):
        ps = x.pop('patch_size', min_patch_size)
        assert ps % min_patch_size == 0 and ps >= min_patch_size, \
            'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
        return HybridEmbed(**x, patch_size=ps // min_patch_size, backbone=backbone)

    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      embed_layer=embed_layer).to(args.device)
    decoder = CustomARWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                **args.decoder_args)),
        pad_value=args.pad_token).to(args.device)
    if 'wandb' in args and args.wandb:
        import wandb
        wandb.watch((encoder, decoder.net.attn_layers))
    model = Model(encoder, decoder, args)
    if training:
        # check if largest batch can be handled by system
        im = torch.empty(args.batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
        seq = torch.randint(0, args.num_tokens, (args.batchsize, args.max_seq_len), device=args.device).long()
        decoder(seq, context=encoder(im)).sum().backward()
        model.zero_grad()
        torch.cuda.empty_cache()
        del im, seq
    return model
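# Worked example of the patch-size constraint above (the backbone_layers value is
# illustrative, not taken from any particular config): a 3-stage backbone forces the
# ViT patch size to be a multiple of 16, and HybridEmbed then operates on the
# already-downsampled feature map.
backbone_layers = [2, 3, 7]                       # three ResNet stages
min_patch_size = 2**(len(backbone_layers) + 1)    # 2**4 = 16
assert min_patch_size == 16
# valid choices for args.patch_size: 16, 32, 48, ...; HybridEmbed receives patch_size // 16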
def initialize(arguments=None):
    if arguments is None:
        arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth',
                           'no_cuda': True, 'no_resize': False})
    # suppress logging noise (including TensorFlow's C++ warnings)
    logging.getLogger().setLevel(logging.FATAL)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    with open(arguments.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params))
    args.update(**vars(arguments))
    args.wandb = False
    args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
    model = get_model(args)
    model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))

    # load the optional resizer network if it is stored next to the main checkpoint
    if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
        image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(args.max_dimensions) // 32, global_pool='avg',
                                 in_chans=1, drop_rate=.05, preact=True, stem_type='same',
                                 conv_layer=StdConv2dSame).to(args.device)
        image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'),
                                                 map_location=args.device))
        image_resizer.eval()
    else:
        image_resizer = None
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
    return args, model, image_resizer, tokenizer
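# Usage sketch for the variant above (assumes the default config and checkpoint paths
# baked into the function actually exist on disk):
args, model, image_resizer, tokenizer = initialize()
# returns the merged config, the encoder-decoder model, the optional resizer
# (already in eval mode, or None) and the tokenizer used to decode predicted ids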
def initialize(arguments):
    filename = join(dirname(__file__), arguments.config)
    with open(filename, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = Munch(params)
    args.update(**vars(arguments))
    args.wandb = False
    args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
    model = get_model(args)
    model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
    if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
        image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=22, global_pool='avg', in_chans=1, drop_rate=.05,
                                 preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
        image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'),
                                                 map_location=args.device))
        image_resizer.eval()
    else:
        image_resizer = None
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
    return args, model, image_resizer, tokenizer
def main(args):
    # data
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(batchsize=args.batchsize, test=False, max_dimensions=args.max_dimensions,
                      keep_smaller_batches=True, device=args.device)
    valloader = Im2LatexDataset().load(args.valdata)
    valloader.update(batchsize=args.batchsize, test=True, max_dimensions=args.max_dimensions,
                     keep_smaller_batches=True, device=args.device)
    # model: one output class per 32 px of the largest allowed dimension
    model = ResNetV2(layers=[2, 3, 3], num_classes=int(max(args.max_dimensions) // 32), global_pool='avg',
                     in_chans=args.channels, drop_rate=.05, preact=True, stem_type='same',
                     conv_layer=StdConv2dSame).to(args.device)
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
    opt = Adam(model.parameters(), lr=args.lr)
    crit = nn.CrossEntropyLoss()
    sched = OneCycleLR(opt, .005, total_steps=args.num_epochs * len(dataloader))
    global bestacc
    bestacc = val(valloader, model, args.valbatches, args.device)

    def train_epoch(sched=None):
        iter(dataloader)
        dset = tqdm(range(len(dataloader)))
        for i in dset:
            im, label = prepare_data(dataloader)
            if im is not None:
                # skip batches that exceed the configured maximum dimensions
                if im.shape[-1] > dataloader.max_dimensions[0] or im.shape[-2] > dataloader.max_dimensions[1]:
                    continue
                opt.zero_grad()
                label = label.to(args.device)
                pred = model(im.to(args.device))
                loss = crit(pred, label)
                if i % 2 == 0:
                    dset.set_description('Loss: %.4f' % loss.item())
                loss.backward()
                opt.step()
                if sched is not None:
                    sched.step()
            # periodically validate and keep the best checkpoint
            if (i + 1) % args.sample_freq == 0 or i + 1 == len(dset):
                acc = val(valloader, model, args.valbatches, args.device)
                print('Accuracy %.2f' % (100 * acc), '%')
                global bestacc
                if acc > bestacc:
                    torch.save(model.state_dict(), args.out)
                    bestacc = acc

    for _ in range(args.num_epochs):
        train_epoch(sched)
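# Inference sketch for a trained resizer (hypothetical helper, not part of the original
# code): since num_classes = max(max_dimensions) // 32, each output class is read here as
# a 32 px width bucket. The bucket-to-pixel conversion below is an assumption, and
# prepare_data, which produces the actual training labels, is not shown above.
def predict_resize_width(model, im, device='cpu'):
    model.eval()
    with torch.no_grad():
        bucket = model(im.to(device)).argmax(-1)   # most likely bucket per image
    return (bucket + 1) * 32                       # assumed mapping back to pixels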