def main(device):
    """Train BiSeNet on COCO and write a checkpoint after every epoch.

    Args:
        device: String assigned to the ``CUDA_VISIBLE_DEVICES`` environment
            variable (for example ``"0"`` or ``"0,1"``).

    Side effects:
        * Sets the module-level global ``total_iter``.
        * Saves ``bisenet_Epoch_<epoch>.pth`` (model ``state_dict`` plus the
          metric recorder) into ``CONFIG.SOLVER.SAVE_PATH`` after each epoch.
    """
    global total_iter

    os.environ["CUDA_VISIBLE_DEVICES"] = device
    torch.backends.cudnn.benchmark = True

    # Create the checkpoint directory before training starts; otherwise a
    # missing directory is only discovered by torch.save after a full epoch
    # of work has already been done.
    save_dir = CONFIG.SOLVER.SAVE_PATH
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    batch_size = CONFIG.DATALOADER.BATCH_SIZE.TRAIN
    num_workers = CONFIG.DATALOADER.WORKERS

    train_data = COCO(CONFIG.DATASET.COCO, 'train2017', req_label=True,
                      req_augment=True, scales=CONFIG.DATASET.SCALES,
                      flip=True)
    val_data = COCO(CONFIG.DATASET.COCO, 'val2017', req_label=True)

    train_loader = DataLoader(train_data, batch_size=batch_size,
                              drop_last=True, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    # Validation must not be shuffled: combined with drop_last=True a shuffled
    # loader evaluates a different random subset every epoch, which makes the
    # recorded validation metrics incomparable between epochs.
    # NOTE(review): drop_last and the TRAIN batch size are kept because
    # Epoch_Step may rely on full, fixed-size batches -- confirm.
    val_loader = DataLoader(val_data, batch_size=batch_size,
                            drop_last=True, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    # Total number of optimisation steps, stored as a float in a global.
    # NOTE(review): presumably consumed by the LR schedule inside Epoch_Step;
    # with drop_last=True the real step count is slightly lower -- confirm.
    total_iter = float(int(CONFIG.SOLVER.EPOCHS * len(train_data) / batch_size))
    print("Dataset Ready.")

    seg_model = nn.DataParallel(
        BiSeNet(CONFIG.DATASET.CLASS_NUM, CONFIG.DATASET.IGNORE_LABEL).cuda())
    optimizer = torch.optim.SGD(seg_model.parameters(),
                                lr=CONFIG.SOLVER.INITIAL_LR,
                                momentum=CONFIG.SOLVER.MOMENTUM,
                                weight_decay=CONFIG.SOLVER.WEIGHT_DECAY,
                                nesterov=True)
    # Per-phase metric history; it is handed to Epoch_Step and stored in
    # every checkpoint.
    recorder = {"Train": {'loss': [], 'PA': [], 'MA': [], 'MI': []},
                "Val": {'loss': [], 'PA': [], 'MA': [], 'MI': []}}
    print("Model Ready.")

    for epoch in range(CONFIG.SOLVER.EPOCHS):
        Epoch_Step(seg_model, train_loader, optimizer, epoch, recorder)
        Epoch_Step(seg_model, val_loader, optimizer, epoch, recorder,
                   Train=False)

        name = "Epoch_{:d}".format(epoch)
        results = {
            'model': seg_model.state_dict(),
            'recorder': recorder,
        }
        torch.save(results,
                   os.path.join(save_dir, 'bisenet_{}.pth'.format(name)))
def main(device):
    """Extract ResNet features and pooled BiSeNet scores for COCO images.

    For each of the ``train2014`` and ``val2014`` splits, every batch is run
    through ``ResNet152`` and through the pretrained BiSeNet (whose output is
    max-pooled to 14x14), and the results are written as float16 into the
    HDF5 file ``CONFIG_VQA.DATASET.COCO_PROCESSED`` together with the image
    ids yielded by the loader.

    Args:
        device: String assigned to the ``CUDA_VISIBLE_DEVICES`` environment
            variable.

    Side effects:
        Creates the datasets ``<name>_feature``, ``<name>_semantic`` and
        ``<name>_ids`` (``name`` being ``train_image`` / ``val_image``) in the
        output HDF5 file.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = device
    torch.backends.cudnn.benchmark = True

    result = torch.load(os.path.join('model', CONFIG_VQA.MODEL.BISENET))

    seg_model = nn.DataParallel(
        BiSeNet(CONFIG_BIS.DATASET.CLASS_NUM,
                CONFIG_BIS.DATASET.IGNORE_LABEL).cuda())
    seg_model.load_state_dict(result['model'])
    seg_model.eval()
    resnet = nn.DataParallel(ResNet152().cuda())
    resnet.eval()
    # Pools the segmentation scores down to the 14x14 grid the
    # ``*_semantic`` dataset is allocated with.
    pool = nn.AdaptiveMaxPool2d(14).cuda()
    print('Model Ready.')

    for split, name in zip(['train2014', 'val2014'],
                           ['train_image', 'val_image']):
        loader = DataLoader(COCO(CONFIG_VQA.DATASET.COCO, split),
                            batch_size=CONFIG_BIS.DATALOADER.BATCH_SIZE.TEST,
                            pin_memory=True,
                            num_workers=CONFIG_BIS.DATALOADER.WORKERS,
                            shuffle=False)
        num_images = len(loader.dataset)
        # NOTE(review): assumes ResNet152 returns (N, 2048, 14, 14) maps and
        # that BiSeNet emits 182 score channels -- confirm against the model
        # definitions / CONFIG_BIS.DATASET.CLASS_NUM.
        feature_shape = (num_images, 2048, 14, 14)
        semantic_shape = (num_images, 182, 14, 14)
        id_shape = (num_images, )

        # Open in append mode explicitly: h5py >= 3.0 defaults to read-only,
        # which makes create_dataset fail, and 'w' would truncate the file and
        # discard the datasets written for the previous split.
        with h5.File(CONFIG_VQA.DATASET.COCO_PROCESSED, 'a',
                     libver='latest') as f:
            features = f.create_dataset(name + '_feature',
                                        shape=feature_shape,
                                        dtype='float16')
            semantics = f.create_dataset(name + '_semantic',
                                         shape=semantic_shape,
                                         dtype='float16')
            ids = f.create_dataset(name + '_ids',
                                   shape=id_shape,
                                   dtype='int32')

            with torch.no_grad():
                # Write offset of the current batch inside the datasets.
                i = 0
                for image, COCOid in tqdm(loader):
                    image = image.cuda()
                    batch = image.size(0)
                    end = i + batch

                    feature = resnet(image).detach().cpu().numpy().astype(
                        np.float16)

                    score = seg_model(image)
                    score = pool(score).detach().cpu().numpy().astype(
                        np.float16)

                    features[i:end] = feature
                    semantics[i:end] = score
                    ids[i:end] = COCOid.detach().numpy().astype(np.int32)

                    i = end
    # Parse the command line; ``parser`` is defined outside this fragment.
    args = parser.parse_args()
    print(args)

    print('Meshed-Memory Transformer Training')

    # Summary writer logging under <logs_folder>/<exp_name>.
    writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))

    # Pipeline for image regions
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    train_dataset, val_dataset, test_dataset = dataset.splits

    # The vocabulary is cached on disk per experiment name: build and pickle it
    # on the first run, reload it afterwards. Files are opened with ``with`` so
    # the handles are closed deterministically.
    vocab_path = 'vocab_%s.pkl' % args.exp_name
    if not os.path.isfile(vocab_path):
        print("Building vocabulary")
        text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
        with open(vocab_path, 'wb') as vocab_file:
            pickle.dump(text_field.vocab, vocab_file)
    else:
        # NOTE: unpickling can execute arbitrary code; only load vocabulary
        # files produced by this script.
        with open(vocab_path, 'rb') as vocab_file:
            text_field.vocab = pickle.load(vocab_file)

    # Model and dataloaders
    # NOTE(review): the positional constants (3, 0) and (54, 3) are passed
    # through unchanged; their meaning is defined by the encoder/decoder
    # constructors, which are not visible here -- confirm before changing.
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': args.m})
    decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)