Пример #1
0
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Build an efficientnet model by name and load pretrained weights.

        Args:
            model_name (str): Name of the efficientnet variant.
            weights_path (None or str):
                str: path to a pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights trained with advprop
                (valid when weights_path is None).
            in_channels (int): Channel count of the input data.
            num_classes (int):
                Number of categories for classification.
                It controls the output size of the final linear layer.
            override_params (other keyword params):
                Params overriding the model's global_params.
                Optional keys:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum', 'batch_norm_epsilon',
                    'drop_connect_rate', 'depth_divisor', 'min_depth'
                ('num_classes' is bound to the explicit parameter above.)

        Returns:
            A pretrained efficientnet model.
        """
        # load_fc is requested only for the default 1000-class head.
        wants_fc = num_classes == 1000
        net = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(net,
                                model_name,
                                weights_path=weights_path,
                                load_fc=wants_fc,
                                advprop=advprop)
        # The input channel count is adjusted after the weights are loaded.
        net._change_in_channels(in_channels)
        return net
Пример #2
0
 def from_pretrained(cls, model_name, num_classes=1000):
     """Create an EfficientNet by name and load pretrained weights into it.

     ``load_fc`` is passed as true only when ``num_classes`` equals 1000.

     Returns:
         The model produced by ``EfficientNet.from_name``.
     """
     overrides = {'num_classes': num_classes}
     net = EfficientNet.from_name(model_name, override_params=overrides)
     load_pretrained_weights(net, model_name, load_fc=num_classes == 1000)
     return net
Пример #3
0
    def net_initialize(self,
                       startup_prog=None,
                       pretrain_weights=None,
                       resume_weights=None):
        """Run the startup program, then optionally restore weights.

        ``resume_weights`` takes precedence over ``pretrain_weights``.

        Args:
            startup_prog: Program to run first; when None,
                ``fluid.default_startup_program()`` is used.
            pretrain_weights: Handed to ``utils.load_pretrained_weights``
                when no checkpoint is being resumed.
            resume_weights: Checkpoint directory. The last '_'-separated
                piece of its base name must be all digits; that number is
                stored in ``self.begin_epoch``.

        Raises:
            Exception: ``resume_weights`` does not exist.
            ValueError: the directory name does not end in an epoch number.
        """
        program = (fluid.default_startup_program()
                   if startup_prog is None else startup_prog)
        self.exe.run(program)

        if resume_weights is None:
            if pretrain_weights is not None:
                logging.info(
                    "Load pretrain weights from {}.".format(pretrain_weights))
                utils.load_pretrained_weights(self.exe, self.train_prog,
                                              pretrain_weights)
            return

        logging.info("Resume weights from {}".format(resume_weights))
        if not osp.exists(resume_weights):
            raise Exception("Path {} not exists.".format(resume_weights))
        fluid.load(self.train_prog, osp.join(resume_weights, 'model'),
                   self.exe)

        # Drop one trailing path separator so basename() is not empty.
        ckpt_dir = resume_weights
        if ckpt_dir[-1] == os.sep:
            ckpt_dir = ckpt_dir[:-1]
        # The epoch number is the final '_'-separated piece of the name.
        epoch_tag = osp.basename(ckpt_dir).split('_')[-1]
        if not epoch_tag.isdigit():
            raise ValueError("Resume model path is not valid!")
        self.begin_epoch = int(epoch_tag)
        logging.info("Model checkpoint loaded successfully!")
Пример #4
0
def EfficientNetB0(pretrained=True, frozen_blocks=5):
    """Build an 'efficientnet-b0' model, optionally with pretrained weights.

    Args:
        pretrained: When truthy, ``load_pretrained_weights`` is called with
            ``weights_path=None``, ``load_fc=False`` and ``advprop=False``.
        frozen_blocks: Forwarded to the ``EfficientNet`` constructor.

    Returns:
        The constructed model.
    """
    variant = 'efficientnet-b0'
    net = EfficientNet(variant, frozen_blocks=frozen_blocks)
    if not pretrained:
        return net
    load_pretrained_weights(net, variant, weights_path=None, load_fc=False,
                            advprop=False)
    return net
Пример #5
0
 def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
     """Create a model by name, load pretrained weights, and adapt the stem.

     ``load_fc`` is passed as true only when ``num_classes`` equals 1000.
     When ``in_channels`` differs from 3, ``_conv_stem`` is replaced by a
     newly constructed convolution after the weights have been loaded.

     Returns:
         The resulting model.
     """
     net = cls.from_name(model_name,
                         override_params={'num_classes': num_classes})
     load_pretrained_weights(net, model_name, load_fc=num_classes == 1000)
     if in_channels == 3:
         return net

     # Rebuild the stem so it accepts the requested number of input channels.
     params = net._global_params
     same_pad_conv = get_same_padding_conv2d(image_size=params.image_size)
     stem_width = round_filters(32, params)
     net._conv_stem = same_pad_conv(in_channels,
                                    stem_width,
                                    kernel_size=3,
                                    stride=2,
                                    bias=False)
     return net
Пример #6
0
def main():
    """Train an Effnet classifier, checkpointing on each new best validation accuracy.

    Reads settings from ``get_arguments()``, restricts visible GPUs via
    ``CUDA_VISIBLE_DEVICES``, and runs ``args.epochs`` epochs of ``train``
    followed by ``val``.
    """
    args = get_arguments()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    logger = Logger(args)

    # NOTE(review): 'width_coeffficient' is spelled with three f's; confirm it
    # matches the parameter name declared by Effnet.
    net = Effnet(num_classes=5,
                 width_coeffficient=1.2,
                 depth_coefficient=1.4,
                 drop_out=0.3)
    net = load_pretrained_weights(net, 'efficientnet-b3')

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    criterion = nn.CrossEntropyLoss()
    train_loader, val_loader, test_loader = get_data()
    # Step decay: multiply the learning rate by 0.1 every 20 scheduler steps.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    net.cuda()
    # Device ids are fixed to the first four visible GPUs.
    net = torch.nn.DataParallel(net, [0, 1, 2, 3])

    top_val_acc = 0.0
    first_epoch, last_epoch = 1, args.epochs
    for epoch in range(first_epoch, last_epoch + 1):
        train(args, net, train_loader, optimizer, criterion, epoch,
              scheduler, logger)
        current_acc = val(args, net, val_loader, optimizer, criterion, epoch,
                          logger)
        if current_acc > top_val_acc:
            top_val_acc = current_acc
            save_checkpoint(args,
                            net.state_dict(),
                            filename='epoch_{}_best_{}.pth'.format(
                                epoch, current_acc))
Пример #7
0
 def from_name(cls, model_name, heads, head_conv, pretrained=False):
     """Construct an EfficientNet with the given heads from a model name.

     The name is validated first. When ``pretrained`` is truthy the value
     returned by ``load_pretrained_weights`` is returned instead of the
     freshly built model.
     """
     cls._check_model_name_is_valid(model_name)
     blocks_args, global_params = get_model_params(model_name, None)
     net = EfficientNet(blocks_args, global_params, heads,
                        head_conv=head_conv)
     if not pretrained:
         return net
     return load_pretrained_weights(net, model_name)
Пример #8
0
    def __init__(self, input_width: int, input_height: int,
                 weight_filepath: str, batch_size: str, num_classes: int,
                 patch_height: int, patch_width: int, norm_mean: List[float],
                 norm_std: List[float], GPU: bool):
        """Store the encoder settings and build the OSNet model in eval mode.

        The model is created with ``osnet_ibn_x1_0``, switched to evaluation
        mode, given the weights at ``weight_filepath`` and, when ``GPU`` is
        truthy, moved to CUDA. The return value of
        ``load_pretrained_weights`` is kept in ``self.weights_loaded``.
        """
        # NOTE(review): batch_size is annotated as str; confirm whether an
        # int was intended.
        self._input_width, self._input_height = input_width, input_height
        self._weight_filepath = weight_filepath
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.patch_size = (patch_height, patch_width)
        self.norm_mean, self.norm_std = norm_mean, norm_std
        self.GPU = GPU

        self._model = osnet_ibn_x1_0(num_classes=num_classes,
                                     loss=OsNetEncoder.LOSS,
                                     pretrained=OsNetEncoder.PRETRAINED_MODEL,
                                     use_gpu=GPU)
        # Put the torch model in evaluation mode before loading the weights.
        self._model.eval()
        self.weights_loaded = load_pretrained_weights(
            model=self._model, weight_path=weight_filepath)
        if GPU:
            self._model = self._model.cuda()
Пример #9
0
    # Command-line options visible in this excerpt; `parser` is created
    # outside the visible lines.
    parser.add_argument('--data_path', default='/path/to/davis/', type=str)
    parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames")
    parser.add_argument("--size_mask_neighborhood", default=12, type=int,
        help="We restrict the set of source nodes considered to a spatial neighborhood of the query node")
    parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors")
    parser.add_argument("--bs", type=int, default=6, help="Batch size, try to reduce if OOM")
    args = parser.parse_args()

    # Print what utils.get_sha() reports, then every parsed argument sorted by name.
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))

    # building network
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    # Disable gradients on every parameter and switch to eval mode.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    # Download the palette: each line holds space-separated integers; the
    # values are regrouped into rows of three uint8 entries.
    color_palette = []
    for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"):
        color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")])
    color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3)

    # One video name per line of the validation list.
    # NOTE(review): the file object opened here is never explicitly closed.
    video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines()
    for i, video_name in enumerate(video_list):
        video_name = video_name.strip()
        print(f'[{i}/{len(video_list)}] Begin to segmentate video {video_name}.')
        video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
        frame_list = read_frame_list(video_dir)
Пример #10
0
def eval_linear(args):
    """Train and evaluate a linear classifier on top of a pretrained backbone.

    Only ``linear_classifier``'s parameters are handed to the optimizer; the
    backbone is put in eval mode and loaded from ``args.pretrained_weights``.
    Training resumes from ``<args.output_dir>/checkpoint.pth.tar`` via
    ``utils.restart_from_checkpoint``. On the main process, one JSON line of
    statistics is appended to ``log.txt`` and a checkpoint is written after
    every epoch.

    Args:
        args: Parsed options. This function reads ``data_path``,
            ``batch_size_per_gpu``, ``num_workers``, ``arch``, ``patch_size``,
            ``pretrained_weights``, ``checkpoint_key``, ``n_last_blocks``,
            ``avgpool_patchtokens``, ``num_labels``, ``gpu``, ``lr``,
            ``epochs``, ``output_dir`` and ``val_freq``.
    """
    utils.init_distributed_mode(args)
    # Print what utils.get_sha() reports, then every argument sorted by name.
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    # Training uses random crop + flip; validation uses resize + center crop.
    train_transform = pth_transforms.Compose([
        pth_transforms.RandomResizedCrop(224),
        pth_transforms.RandomHorizontalFlip(),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    val_transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
    # Only the training loader gets a DistributedSampler.
    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ building network ... ============
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    model.cuda()
    model.eval()
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    # load weights to evaluate
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

    # Classifier input width: embed_dim times (n_last_blocks, plus one when
    # avgpool_patchtokens is truthy).
    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

    # set optimizer
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
        momentum=0.9,
        weight_decay=0, # we do not apply weight decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)

    # Optionally resume from a checkpoint
    # NOTE(review): presumably restart_from_checkpoint updates `to_restore`
    # in place when a checkpoint exists; confirm against utils.
    to_restore = {"epoch": 0, "best_acc": 0.}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]

    for epoch in range(start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)

        train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
        scheduler.step()

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        # Validate every `val_freq` epochs and always on the final epoch.
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
            print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            best_acc = max(best_acc, test_stats["acc1"])
            print(f'Max accuracy so far: {best_acc:.2f}%')
            log_stats = {**{k: v for k, v in log_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()}}
        # Only the main process writes the log line and the checkpoint.
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            # "epoch" stores the next epoch to run, so a resumed run continues there.
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
    print("Training of the supervised linear classifier on frozen features completed.\n"
                "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
Пример #11
0
 def load(self, path):
     """Load weights into this object by calling ``load_pretrained_weights(self, path)``.

     The helper's return value is discarded; this method returns None.
     """
     load_pretrained_weights(self, path)
Пример #12
0
def extract_feature_pipeline(args):
    """Extract backbone features and labels for the train and val splits.

    The model's input layer is replaced through
    ``utils.replace_input_layer(model, inchannels=13)`` before the pretrained
    weights are loaded. Features are L2-normalized along dim 1 on rank 0 only.
    When ``args.dump_features`` is set, rank 0 also saves the four tensors
    into that directory.

    Returns:
        Tuple ``(train_features, test_features, train_labels, test_labels)``.
    """
    # ============ preparing data ... ============
    transform = pth_transforms.Compose([pth_transforms.CenterCrop(96)])

    # Both splits share every option except the split name.
    # NOTE(review): 'tansform_coord' looks like a misspelling; confirm it
    # matches the keyword declared by ReturnIndexDataset.
    shared_dataset_options = dict(
        transform=transform,
        tansform_coord=None,
        classes=None,
        seasons=None,
        split_by_region=True,
        download=False,
    )
    dataset_train = ReturnIndexDataset(args.data_path, "train",
                                       **shared_dataset_options)
    dataset_val = ReturnIndexDataset(args.data_path, "val",
                                     **shared_dataset_options)

    shared_loader_options = dict(
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    # Only the training loader uses a (non-shuffling) distributed sampler.
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler, **shared_loader_options)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, **shared_loader_options)
    print(
        f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
    )

    # ============ building network ... ============
    build_backbone = vits.__dict__[args.arch]
    model = build_backbone(patch_size=args.patch_size, num_classes=0)
    utils.replace_input_layer(model, inchannels=13)
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights,
                                  args.checkpoint_key, args.arch,
                                  args.patch_size)
    model.eval()

    # ============ extract features ... ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val)

    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    # The label is the last element of every dataset sample entry.
    train_labels = torch.tensor([sample[-1] for sample in dataset_train.samples]).long()
    test_labels = torch.tensor([sample[-1] for sample in dataset_val.samples]).long()

    # save features and labels
    if args.dump_features and dist.get_rank() == 0:
        dump_plan = (
            ("trainfeat.pth", train_features),
            ("testfeat.pth", test_features),
            ("trainlabels.pth", train_labels),
            ("testlabels.pth", test_labels),
        )
        for file_name, tensor in dump_plan:
            torch.save(tensor.cpu(),
                       os.path.join(args.dump_features, file_name))
    return train_features, test_features, train_labels, test_labels
Пример #13
0
def extract_feature_pipeline(args):
    """Extract backbone features and labels for the train and val folders.

    Images are resized to 256, center-cropped to 224, converted to tensors
    and normalized. Features are L2-normalized along dim 1 on rank 0 only.
    When ``args.dump_features`` is set, rank 0 also saves the four tensors
    into that directory.

    Returns:
        Tuple ``(train_features, test_features, train_labels, test_labels)``.
    """
    # ============ preparing data ... ============
    preprocessing = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    train_root = os.path.join(args.data_path, "train")
    val_root = os.path.join(args.data_path, "val")
    dataset_train = ReturnIndexDataset(train_root, transform=preprocessing)
    dataset_val = ReturnIndexDataset(val_root, transform=preprocessing)

    shared_loader_options = dict(
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    # Only the training loader uses a (non-shuffling) distributed sampler.
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler, **shared_loader_options)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, **shared_loader_options)
    print(
        f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
    )

    # ============ building network ... ============
    build_backbone = vits.__dict__[args.arch]
    model = build_backbone(patch_size=args.patch_size, num_classes=0)
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights,
                                  args.checkpoint_key, args.arch,
                                  args.patch_size)
    model.eval()

    # ============ extract features ... ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val)

    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    # The label is the last element of every dataset sample entry.
    train_labels = torch.tensor([sample[-1] for sample in dataset_train.samples]).long()
    test_labels = torch.tensor([sample[-1] for sample in dataset_val.samples]).long()

    # save features and labels
    if args.dump_features and dist.get_rank() == 0:
        dump_plan = (
            ("trainfeat.pth", train_features),
            ("testfeat.pth", test_features),
            ("trainlabels.pth", train_labels),
            ("testlabels.pth", test_labels),
        )
        for file_name, tensor in dump_plan:
            torch.save(tensor.cpu(),
                       os.path.join(args.dump_features, file_name))
    return train_features, test_features, train_labels, test_labels
Пример #14
0
 def from_pretrained(cls, model_name):
     """Build an EfficientNet by name and return it after loading pretrained weights.

     The returned object is whatever ``load_pretrained_weights`` returns.
     """
     fresh_model = EfficientNet.from_name(model_name)
     return load_pretrained_weights(fresh_model, model_name)
Пример #15
0
def main():
    """Train a few-shot model (baseline or ProtoNet) and test it afterwards.

    Settings come from ``init()``. After every epoch the model is validated,
    an 'epoch-last' checkpoint is saved, and a 'max_acc' checkpoint is saved
    whenever validation accuracy is at least ``args.max_acc``. With
    ``args.test`` set, only ``test`` is run.

    Raises:
        NotImplementedError: ``args.model_type`` is ``Model_type.ConvNet``;
            no encoder is constructed for it here.
        ValueError: ``args.model_type`` or ``args.method_type`` is not one of
            the values handled below.
    """
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    # Debug runs use fewer episodes per epoch.
    n_episode = 10 if args.debug else 100
    # The baseline trains on plain batches; other methods train on episodes.
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
    train_loader = train_datamgr.get_data_loader(aug=True)

    # Validation is always episodic and unaugmented.
    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # No encoder is built for ConvNet; fail here with a clear message
        # instead of hitting an unbound `encoder` further down.
        raise NotImplementedError(
            'Model_type.ConvNet is not supported: no encoder is built for it')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('Unsupported model_type: {!r}'.format(args.model_type))

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('Unsupported method_type: {!r}'.format(args.method_type))

    from torch.optim import SGD,lr_scheduler
    # Only the encoder's parameters are optimized in both configurations.
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # Label tensor: each class index 0..n_way-1 repeated n_query times.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    # Warm-up weights are loaded only when no checkpoint was resumed.
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        # Debug mode runs a single epoch.
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
Пример #16
0
# Build one batched processing sequence per dataset split.
sequencers = {}
for split in splits:
    data_manager = data_managers[split]  # not used again in the visible lines
    image_shape = (args.image_size, args.image_size)  # square images
    processor = AugmentHandSegmentation(image_shape)
    sequencers[split] = ProcessingSequence(processor, args.batch_size,
                                           datasets[split])

model = Hand_Segmentation_Net()
# from_logits=True: the loss is told the model outputs are raw logits.
loss = CategoricalCrossentropy(from_logits=True)

model.compile(loss=loss, optimizer=Adam(), metrics=['mean_squared_error'])

# Optionally replace the model with the one returned by the weight loader.
if args.load_pretrained_weights:
    model = load_pretrained_weights(args.pretrained_weights_path,
                                    model=model,
                                    num_layers=16)

# creating directory for experiment
callbacks = []
# Experiment label: "<dataset>_<model name>_<run label>".
experiment_label = '_'.join([args.dataset, model.name, args.run_label])
experiment_path = os.path.join(args.save_path, experiment_label)
if not os.path.exists(experiment_path):
    os.makedirs(experiment_path)

# setting additional callbacks
log = CSVLogger(os.path.join(experiment_path, 'optimization.log'))
stop = EarlyStopping(patience=args.stop_patience)
plateau = ReduceLROnPlateau(patience=args.reduce_patience)
save_filename = os.path.join(experiment_path, 'model.hdf5')
# Only the best model seen so far is kept on disk.
save = ModelCheckpoint(save_filename, save_best_only=True)
# Fine-tuning configuration.
IMG_HEIGHT_WIDTH = (299, 299)
INPUT_SHAPE = (299, 299, 3)
BATCH_SIZE = 64
OPTIMIZER = 'adam'
FREEZE_PROPORTION = 0.99
LIST_OF_NAMES = False

# data generators
weights_path = "/home/nhannguyen/Kaggle_IEEE/model_weights/inception_resnet_v2/"
train_generator, validation_generator = train_validation_generator(batch_size=BATCH_SIZE,
                                                                   img_height_width=IMG_HEIGHT_WIDTH)
# Hard-coded sample counts.
NUM_TRAIN_SAMPLES = 147060
NUM_VALID_SAMPLES = 1440

# create computational graph and load pre-trained ImageNet weights
inceptionresnetv2 = load_pretrained_weights('InceptionResNetV2', input_shape=INPUT_SHAPE)

# overwrite ImageNet weights with weights obtained from last training epoch
inceptionresnetv2.load_weights(weights_path + 'multigpu7.hdf5')

# fine-tune
freeze_layers(list_of_names=LIST_OF_NAMES,
              trainable_layers_names=None,
              model=inceptionresnetv2,
              freeze_proportion=FREEZE_PROPORTION)
# NOTE(review): NUM_GPUS is not defined in the visible lines; confirm it is
# set earlier in the script.
parallel_model = multi_gpu_model(model=inceptionresnetv2, gpus=NUM_GPUS)
# Per-stage schedule: entry i of each list gives that stage's epoch count and
# learning rate (both lists have eight entries).
multi_stages_epochs = [10, 10, 10, 20, 20, 20, 20, 20]
multi_stages_learning_rate = [1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-4, 1e-4, 1e-4]
for i, stage_epochs in enumerate(multi_stages_epochs):
    history = train_on_multi_gpus(parallel_model=parallel_model,
                                  train_generator=train_generator,