Example #1
def build_model(config, device, train=True):
    # load model
    if config['model'] == 'default':
        net = model.Resnet50()
    elif config['model'] == 'fused':
        net = model_fused.Resnet50()
    elif config['model'] == 'quant':
        net = model_quant.Resnet50()
    elif config['model'] == 'tf':
        net = model_tf.Resnet50()
    elif config['model'] == 'tf_fused':
        net = model_tf_fused.Resnet50()
    else:
        raise ValueError('cannot load model, check config file')
    # load loss
    if config['loss'] == 'cross_entropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        raise ValueError('cannot load loss, check config file')

    net = net.to(device)
    loss_fn = loss_fn.to(device)

    if not train:
        return net, loss_fn
    # load optimizer
    if config['optimizer'] == 'sgd':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           net.parameters()),
                                    lr=config['learning_rate'],
                                    momentum=0.9,
                                    weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            net.parameters()),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    else:
        raise ValueError('cannot load optimizer, check config file')
    # load scheduler
    if config['scheduler'] == 'cosine':
        scheduler_step = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['t_max'])
    elif config['scheduler'] == 'step':
        scheduler_step = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config['lr_decay_every'],
            gamma=config['lr_decay'])
    else:
        raise ValueError('cannot load scheduler, check config file')
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=config['lr_multiplier'],
                                       total_epoch=config['lr_epoch'],
                                       after_scheduler=scheduler_step)

    return net, loss_fn, optimizer, scheduler
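A minimal usage sketch for build_model, assuming the imports of the source file (torch, nn, the model modules, GradualWarmupScheduler) are in scope. The config keys mirror the ones the function reads, but every value below is an illustrative assumption, not taken from the source:

import torch

config = {
    'model': 'default',
    'loss': 'cross_entropy',
    'optimizer': 'sgd',
    'learning_rate': 0.1,   # illustrative values only
    'weight_decay': 1e-4,
    'scheduler': 'cosine',
    't_max': 90,
    'lr_multiplier': 10,
    'lr_epoch': 5,
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net, loss_fn, optimizer, scheduler = build_model(config, device, train=True)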
Example #2
''' ######################## < Step 2 > Create instances ######################## '''

# Build dataloader
print(
    '\n[1 / 3]. Build data loader. Depending on your environment, this may take several minutes..'
)
dloader, dlen = data_loader(dataset_root=config.dataset_root,
                            resize=config.resize,
                            crop=config.crop,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers,
                            type='encoder_train')

# Build models
print('\n[2 / 3]. Build models.. ')
encoder = nn.DataParallel(model.Resnet50(dim=config.feature_dim)).to(dev)
momentum_encoder = nn.DataParallel(
    model.Resnet50(dim=config.feature_dim)).to(dev)

# loss history
loss_hist = []

# If resume, load ckpt and loss history
if config.resume:
    file_name = 'ckpt_' + str(config.start_epoch) + '.pkl'
    ckpt = torch.load(os.path.join(weight_path, file_name))
    encoder.load_state_dict(ckpt['encoder'])

    try:
        with open(os.path.join(loss_path, 'loss.pkl'), 'rb') as f:
            iter_per_epoch = int(dlen / config.batch_size)
            # assumed completion (the source snippet is truncated here):
            # restore the pickled loss history saved alongside the checkpoint
            loss_hist = pickle.load(f)
    except FileNotFoundError:
        pass  # no saved loss history; keep the empty list
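In the MoCo recipe, this encoder pairing usually continues by initializing the momentum encoder as a gradient-free copy of the query encoder. That step is not visible in this truncated snippet, so the following is an assumed sketch, not the source's code:

momentum_encoder.load_state_dict(encoder.state_dict())  # assumed MoCo-style init
for param in momentum_encoder.parameters():
    param.requires_grad = False  # updated by momentum, not by backprop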
Example #3
    # opening lines reconstructed to mirror the val_loader call below;
    # the source snippet is truncated above
    train_loader = torch.utils.data.DataLoader(data.Data(args,
                                                         mode='train',
                                                         model_type='rnn'),
                                               batch_size=args.train_batch,
                                               num_workers=args.workers,
                                               shuffle=True)
    #    collate_fn=data.collate_fn)

    val_loader = torch.utils.data.DataLoader(data.Data(args,
                                                       mode='val',
                                                       model_type='rnn'),
                                             batch_size=args.train_batch,
                                             num_workers=args.workers,
                                             shuffle=False)
    ''' load model '''
    print('===> prepare model ...')
    feature_extractor = model.Resnet50()
    RNNClassifier = model.RNNClassifier(args)

    feature_extractor.cuda()
    RNNClassifier.cuda()

    clf_criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(RNNClassifier.parameters(),
                                lr=args.lr,
                                momentum=0.9)
    ''' setup tensorboard '''
    writer = SummaryWriter(os.path.join(args.save_dir, 'train_info'))

    print('===> start training ...')
    iters = 0
    best_acc = 0
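The snippet cuts off right before the loop itself. A hypothetical sketch of how these pieces typically fit together, assuming each batch from train_loader yields frame tensors and labels (the variable names mirror the snippet; the loop body is not from the source):

for epoch in range(args.epoch):  # assumed arg name
    for frames, labels in train_loader:
        frames, labels = frames.cuda(), labels.cuda()
        with torch.no_grad():
            feats = feature_extractor(frames)  # only the RNN is optimized
        loss = clf_criterion(RNNClassifier(feats), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar('train/loss', loss.item(), iters)
        iters += 1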
Example #4
def make_fused_weights(weight_path):
    save_path = weight_path[:-4] + '_fused.pth'

    checkpoint = torch.load(weight_path)
    if 'model_state_dict' in checkpoint.keys():
        weights = checkpoint['model_state_dict']
    else:
        weights = checkpoint

    net = model.Resnet50()
    net_fused = model_fused.Resnet50()

    new_weights = net_fused.state_dict()

    weight_dict = {}
    running_mean_list = []
    bias_list = []

    # index each '.weight' key by order of appearance; record BN layer names
    weight_num = 0
    for key in weights.keys():
        if key[-6:] == 'weight':
            weight_dict[key[:-7]] = weight_num
            weight_num += 1
        elif key[-12:] == 'running_mean':
            running_mean_list.append(key[:-13])

    # collect biases of layers that have no BatchNorm stats (plain conv/linear biases)
    for key in weights.keys():
        if key[-4:] == 'bias':
            if key[:-5] not in running_mean_list:
                bias_list.append(key[:-5])

    # a BN layer's weight immediately follows its conv's weight in the state
    # dict, so the matching conv is the previous entry in weight_dict
    bn_num = 0
    for bn_name in running_mean_list:
        weight_num = weight_dict[bn_name] - 1
        conv_name = list(weight_dict.keys())[weight_num]
        #        new_conv_name = list(new_weight_dict.keys())[weight_num - bn_num]

        conv_weight, conv_bias = batch_concate(weights,
                                               conv_name,
                                               bn_name,
                                               is_tconv=False)

        new_weights[conv_name + '.weight'] = conv_weight
        new_weights[conv_name + '.bias'] = conv_bias

        print(conv_name, bn_name)

        bn_num += 1

    for bias_name in bias_list:
        print(bias_name)
        new_weights[bias_name + '.weight'] = weights[bias_name +
                                                     '.weight'].double()
        new_weights[bias_name + '.bias'] = weights[bias_name +
                                                   '.bias'].double()

    # compare results: load both weight sets and check the outputs agree
    net.load_state_dict(weights)
    net_fused.load_state_dict(new_weights)

    net.eval()
    net_fused.eval()

    # testing (note: this loop only recomputes layer names and has no effect;
    # it appears to be leftover debugging code)
    bn_num = 0
    for bn_name in running_mean_list:
        weight_num = weight_dict[bn_name] - 1
        conv_name = list(weight_dict.keys())[weight_num]

        conv_path = conv_name.split('.')

        bn_num += 1

    with torch.no_grad():
        x = torch.rand((1, 3, 224, 224))
        output = net(x)
        output_fused = net_fused(x)

    print(output.shape)
    print(output_fused.shape)
    #    print(output - output_fused)
    print(abs(output - output_fused).max())

    torch.save({'model_state_dict': new_weights}, save_path)
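batch_concate is defined elsewhere in the source repo and is not shown here. Given how it is called above, it presumably folds each BatchNorm into its preceding conv; the sketch below is an assumption about its behavior, not the repo's implementation (eps and the weight layouts are PyTorch defaults, and the .double() casts match the rest of this example):

import torch

def batch_concate(weights, conv_name, bn_name, is_tconv=False):
    """Assumed behavior: fold a BatchNorm layer into the preceding conv."""
    gamma = weights[bn_name + '.weight'].double()
    beta = weights[bn_name + '.bias'].double()
    mean = weights[bn_name + '.running_mean'].double()
    var = weights[bn_name + '.running_var'].double()
    eps = 1e-5  # PyTorch's BatchNorm default

    w = weights[conv_name + '.weight'].double()
    b = weights.get(conv_name + '.bias')
    b = b.double() if b is not None else torch.zeros_like(mean)

    scale = gamma / torch.sqrt(var + eps)
    # ConvTranspose2d stores weights as (in, out, kH, kW); Conv2d as (out, in, kH, kW)
    if is_tconv:
        w_fused = w * scale.reshape(1, -1, 1, 1)
    else:
        w_fused = w * scale.reshape(-1, 1, 1, 1)
    b_fused = (b - mean) * scale + beta
    return w_fused, b_fused

With such a helper in place, make_fused_weights('resnet50.pth') (hypothetical path) would write resnet50_fused.pth next to the original checkpoint.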
Example #5
# opening line reconstructed to mirror the test loader call below;
# the source snippet is truncated above
trn_dloader, trn_dlen = data_loader(dataset_root=config.dataset_root,
                                    resize=config.resize,
                                    crop=config.crop,
                                    batch_size=config.trn_batch_size,
                                    num_workers=config.num_workers,
                                    type='classifier_train')

tst_dloader, tst_dlen = data_loader(dataset_root=config.dataset_root,
                                    resize=config.resize,
                                    crop=config.crop,
                                    batch_size=config.tst_batch_size,
                                    num_workers=config.num_workers,
                                    type='classifier_test')

# Build models
print('[2 / 2]. Build models.. \n')
encoder = nn.DataParallel(model.Resnet50(dim=config.out_dim)).to(dev)

ckpt_name = 'ckpt_' + str(config.load_pretrained_epoch) + '.pkl'
ckpt_path = os.path.join(config.encoder_output_root,
                         config.encoder_dataset_name,
                         config.encoder_exp_version, 'weight', ckpt_name)
ckpt = torch.load(ckpt_path)
encoder.load_state_dict(ckpt['encoder'])

feature_extractor = nn.Sequential(*list(
    encoder.module.resnet.children())[:-1])  # feature extractor from encoder
linear = nn.Linear(config.in_dim, config.cls_num).to(dev)  # linear classifier

# Freeze encoder
for param in feature_extractor.parameters():
    param.requires_grad = False
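The fragment stops after freezing the backbone. A minimal linear-probe training step consistent with this setup (the optimizer settings and the flatten are assumptions, not from the source):

optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3)  # assumed settings
criterion = nn.CrossEntropyLoss().to(dev)

for img, label in trn_dloader:
    img, label = img.to(dev), label.to(dev)
    with torch.no_grad():
        feat = feature_extractor(img).flatten(start_dim=1)  # (B, config.in_dim)
    loss = criterion(linear(feat), label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()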
Example #6
    print('===> prepare dataloader ...')
    train_loader = torch.utils.data.DataLoader(data.TrimmedVideoData(
        args, mode='train', model_type='cnn'),
                                               batch_size=args.train_batch,
                                               num_workers=args.workers,
                                               shuffle=True)
    #    collate_fn=data.collate_fn)

    val_loader = torch.utils.data.DataLoader(data.TrimmedVideoData(
        args, mode='val', model_type='cnn'),
                                             batch_size=args.train_batch,
                                             num_workers=args.workers,
                                             shuffle=True)
    ''' load model '''
    print('===> prepare model ...')
    feature_extractor, classifier = model.Resnet50(), model.Classifier()

    feature_extractor.cuda()
    classifier.cuda()

    clf_criterion = nn.CrossEntropyLoss()
    clf_optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, args.beta2))
    ''' setup tensorboard '''
    writer = SummaryWriter(os.path.join(args.save_dir, 'train_info'))

    print('===> start training ...')
    iters = 0
    best_acc = 0
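best_acc is tracked against validation accuracy across epochs. A hypothetical evaluation pass consistent with this setup, assuming val_loader yields image/label batches (nothing below is from the source):

def evaluate(feature_extractor, classifier, loader):
    feature_extractor.eval()
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.cuda(), labels.cuda()
            preds = classifier(feature_extractor(imgs)).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total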