def _load_pretrained_or_patch_forward(model, config, device, strict, mode):
    """Load pretrained weights for training, or switch forward() for export.

    When ``config.model.pretrained`` is set and ``mode`` is ``'train'``, the
    checkpoint at ``config.model.imagenet_weights`` is loaded into ``model``.
    Otherwise, when ``mode`` is ``'convert'``, ``model.forward`` is rebound to
    ``model.forward_to_onnx``. Any other combination leaves ``model`` as is.
    """
    if config.model.pretrained and mode == 'train':
        load_checkpoint(config.model.imagenet_weights, model,
                        strict=strict, map_location=device)
    elif mode == 'convert':
        model.forward = model.forward_to_onnx


def _build_loss_head(config):
    """Return ``(head, loss_cfg)`` for the configured loss, or ``(None, None)``.

    ``head`` is the replacement classifier layer and ``loss_cfg`` is the
    matching loss sub-config (the caller reads its ``s`` attribute).
    """
    loss = config.loss
    embedding_dim = config.model.embeding_dim
    if loss.loss_type == 'amsoftmax' and loss.amsoftmax.margin_type != 'cross_entropy':
        return AngleSimpleLinear(embedding_dim, 2), loss.amsoftmax
    if loss.loss_type == 'soft_triple':
        head = SoftTripleLinear(embedding_dim, 2,
                                num_proxies=loss.soft_triple.K)
        return head, loss.soft_triple
    return None, None


def build_model(config, device, strict=True, mode='train'):
    """Build the backbone selected by ``config`` and adapt its head to the loss.

    Args:
        config: configuration object; its ``model``, ``dropout``, ``conv_cd``,
            ``loss`` and ``multi_task_learning`` attributes are read.
        device: forwarded as ``map_location`` when pretrained weights are loaded.
        strict: forwarded to ``load_checkpoint``.
        mode: ``'train'`` loads pretrained weights when
            ``config.model.pretrained`` is set; ``'convert'`` rebinds
            ``model.forward`` to ``model.forward_to_onnx``.

    Returns:
        The constructed model, with ``spoofer`` (Mobilenet2) or ``spoofer[3]``
        and ``scaling`` (Mobilenet3) replaced when the loss is 'amsoftmax'
        with a non-'cross_entropy' margin, or 'soft_triple'.

    Raises:
        ValueError: if ``config.model.model_type`` is neither 'Mobilenet2' nor
            'Mobilenet3', or if a Mobilenet3 ``model_size`` is neither 'large'
            nor 'small'.
    """
    parameters = dict(width_mult=config.model.width_mult,
                      prob_dropout=config.dropout.prob_dropout,
                      type_dropout=config.dropout.type,
                      mu=config.dropout.mu,
                      sigma=config.dropout.sigma,
                      embeding_dim=config.model.embeding_dim,
                      prob_dropout_linear=config.dropout.classifier,
                      theta=config.conv_cd.theta,
                      multi_heads=config.multi_task_learning)

    # Validate explicitly rather than with `assert`: asserts are stripped under
    # `python -O`, which would silently build the wrong architecture.
    model_type = config.model.model_type
    if model_type == 'Mobilenet2':
        is_v2 = True
        model = mobilenetv2(**parameters)
    elif model_type == 'Mobilenet3':
        is_v2 = False
        model_size = config.model.model_size
        if model_size == 'large':
            model = mobilenetv3_large(**parameters)
        elif model_size == 'small':
            model = mobilenetv3_small(**parameters)
        else:
            raise ValueError(
                "Unsupported Mobilenet3 model_size %r; expected 'large' or 'small'"
                % (model_size,))
    else:
        raise ValueError(
            "Unsupported model_type %r; expected 'Mobilenet2' or 'Mobilenet3'"
            % (model_type,))

    _load_pretrained_or_patch_forward(model, config, device, strict, mode)

    head, loss_cfg = _build_loss_head(config)
    if head is not None:
        if is_v2:
            # Mobilenet2 swaps the whole classifier and sets no scaling.
            model.spoofer = head
        else:
            # Mobilenet3 replaces only element 3 of `spoofer` and records
            # the loss scale on the model.
            model.scaling = loss_cfg.s
            model.spoofer[3] = head
    return model
예제 #2
0
def main(args):
    """Run the CTW1500 training loop described by ``args``.

    Reads ``arch``, ``checkpoint``, ``batch_size``, ``n_epoch``, ``pretrain``,
    ``resume``, ``lr``, ``schedule`` and ``img_size`` from ``args``. When
    ``args.checkpoint`` is empty a default directory name is derived and
    written back to ``args.checkpoint``. Model weights are saved under
    ``<checkpoint>/model_tf/weights`` after every epoch and per-epoch metrics
    are appended to ``<checkpoint>/log.txt``.

    Raises:
        ValueError: if ``args.arch`` is not one of the supported architectures.
        FileNotFoundError: if ``args.pretrain`` is set but is not a file.
    """
    # arch name -> (factory attribute on `models`, request pretrained weights).
    # resnet18/34 and the mobilenetv3 variants are built without pretrained
    # weights.
    arch_table = {
        "resnet50": ("resnet50", True),
        "resnet101": ("resnet101", True),
        "resnet152": ("resnet152", True),
        "resnet18": ("resnet18", False),
        "resnet34": ("resnet34", False),
        # NOTE(review): "mobilenetv2" builds a resnet152 here, which looks like
        # a copy-paste slip. Behavior is kept because a mobilenetv2 factory on
        # `models` is not visible from this file -- confirm and fix.
        "mobilenetv2": ("resnet152", True),
        "mobilenetv3large": ("mobilenetv3_large", False),
        "mobilenetv3small": ("mobilenetv3_small", False),
    }
    # Fail fast: an unknown arch used to leave `model` unbound and surface
    # later as a NameError, after directories had already been created.
    if args.arch not in arch_table:
        raise ValueError("Unsupported arch %r; expected one of: %s"
                         % (args.arch, ', '.join(sorted(arch_table))))
    factory_name, use_pretrained = arch_table[args.arch]

    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ctw1500_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.checkpoint, exist_ok=True)

    kernel_num = 7
    min_scale = 0.4
    # NOTE(review): stays 0 even when resuming -- confirm that is intended.
    start_epoch = 0

    data_loader = CTW1500Loader(is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)

    model = getattr(models, factory_name)(pretrained=use_pretrained,
                                          num_classes=kernel_num)

    # NOTE(review): Keras `decay` is documented as learning-rate decay, not
    # L2 weight decay -- confirm this is the intended regularisation.
    optimizer = tf.keras.optimizers.SGD(learning_rate=args.lr,
                                        momentum=0.99,
                                        decay=5e-4)

    title = 'CTW1500'
    log_path = os.path.join(args.checkpoint, 'log.txt')
    log_columns = ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.']
    if args.pretrain:
        print('Using pretrained model.')
        if not os.path.isfile(args.pretrain):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        # NOTE(review): the pretrain file is only checked for existence; no
        # weights are loaded from it in this function -- confirm.
        logger = Logger(log_path, title=title)
        logger.set_names(log_columns)
    elif args.resume:
        print('Resuming from checkpoint.')
        model.load_weights(args.resume)
        logger = Logger(log_path, title=title, resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(log_path, title=title)
        logger.set_names(log_columns)

    weights_path = os.path.join(args.checkpoint, 'model_tf', 'weights')
    try:
        for epoch in range(start_epoch, args.n_epoch):
            optimizer = get_new_optimizer(args, optimizer, epoch)
            current_lr = optimizer.get_config()['learning_rate']
            print('\nEpoch: [%d | %d] LR: %f' %
                  (epoch + 1, args.n_epoch, current_lr))

            train_loader = ctw_train_loader(data_loader,
                                            batch_size=args.batch_size)

            # Only the text (te) metrics are logged; kernel (ke) metrics are
            # returned by train() but unused here.
            train_loss, train_te_acc, _ke_acc, train_te_iou, _ke_iou = train(
                train_loader, model, dice_loss, optimizer, epoch)

            model.save_weights(weights_path)

            logger.append([
                optimizer.get_config()['learning_rate'], train_loss,
                train_te_acc, train_te_iou
            ])
    finally:
        # Close the log even if an epoch raises, so it is not left open.
        logger.close()