Code example #1
0
def load_model(log_dir, cp_num, n_classes=5994):
    """Build a background_resnet classifier and restore its weights.

    Args:
        log_dir: directory containing checkpoint files.
        cp_num: checkpoint number; zero-padded to three digits in the filename.
        n_classes: number of output classes (default 5994).

    Returns:
        The model with the checkpoint's 'state_dict' loaded.
    """
    model = background_resnet(num_classes=n_classes)
    print('=> loading checkpoint')
    # Checkpoint files are named checkpoint_NNN.pth (zero-padded to 3 digits).
    checkpoint = torch.load(f"{log_dir}/checkpoint_{str(cp_num).zfill(3)}.pth")

    # NOTE(review): the original comment claimed a new OrderedDict without the
    # `module.` prefix is created, but no such remapping ever happened here —
    # the state_dict is loaded as-is. Comment removed to stop misleading readers.
    model.load_state_dict(checkpoint['state_dict'])

    return model
Code example #2
0
def load_model(use_cuda, log_dir, cp_num, n_classes):
    """Restore a background_resnet model from a numbered checkpoint.

    The model is optionally moved to the GPU, its weights are loaded from
    ``<log_dir>/checkpoint_<NNN>.pth``, and it is switched to eval mode
    before being returned.
    """
    model = background_resnet(num_classes=n_classes)

    if use_cuda:
        model.cuda()
    print('=> loading checkpoint')

    ckpt_path = log_dir + '/checkpoint_' + str(cp_num).zfill(3) + '.pth'
    saved = torch.load(ckpt_path)
    model.load_state_dict(saved['state_dict'])
    model.eval()
    return model
Code example #3
0
def load_model(use_cuda, log_dir, cp_num, n_classes):
    """Load an eval-mode background_resnet from a saved checkpoint.

    Moves the freshly constructed model to CUDA first when requested, then
    restores the parameters stored under 'state_dict' in the checkpoint file.
    """
    model = background_resnet(num_classes=n_classes)
    if use_cuda:
        model.cuda()
    print('=> loading checkpoint')
    # original saved file with DataParallel
    filename = '/checkpoint_{}.pth'.format(str(cp_num).zfill(3))
    restored = torch.load(log_dir + filename)
    model.load_state_dict(restored['state_dict'])
    model.eval()
    return model
Code example #4
0
File: train.py  Project: wuqiangch/meta-SR
def load_model(log_dir, cp_num, n_classes):
    """Restore a background_resnet model and its optimizer for resuming training.

    Args:
        log_dir: directory containing checkpoint files.
        cp_num: checkpoint number; zero-padded to three digits in the filename.
        n_classes: number of output classes.

    Returns:
        (model, optimizer) tuple with both states restored from the checkpoint.
    """
    model = background_resnet(num_classes=n_classes)
    optimizer = create_optimizer(model)

    print('=> loading checkpoint')
    checkpoint = torch.load(log_dir + '/checkpoint_' + str(cp_num).zfill(3) + '.pth')

    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Optimizer state tensors (e.g. momentum buffers) come back on whatever
    # device torch.load restored them to; move them to the GPU so they match
    # the model's device during training. Guarded so a CPU-only machine no
    # longer crashes here (the original called .cuda() unconditionally).
    if torch.cuda.is_available():
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
    return model, optimizer
Code example #5
0
File: train.py  Project: wuqiangch/meta-SR
def main():
    """Entry point: seed RNGs, build the dataset/model/optimizer (optionally
    resuming from a checkpoint), then run meta-training.

    Relies on module-level state: ``args``, ``max_epoch`` and ``log_dir``
    (presumably set by the CLI parser at import time — TODO confirm).
    """

    if args.seed is not None:
        # Seed every RNG source and force cudnn into deterministic mode so
        # runs are reproducible.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Resume from the epoch after the saved checkpoint, or start fresh at 0.
    if args.use_checkpoint:
        start = args.cp_num + 1
    else:
        start = 0  # Start epoch
    n_epochs = max_epoch - start  # How many epochs remain?

    # Load dataset. One episode consumes (n_shot + n_query) samples for each
    # of nb_class_train classes, so n_episode is how many fit in the dataset.
    train_DB, n_data, n_classes = make_DB(DB_type=args.data_type)
    n_episode = int(n_data / ((args.n_shot + args.n_query) * args.nb_class_train))

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Generate model and optimizer (restored together when resuming).
    if args.use_checkpoint:
        model, optimizer = load_model(log_dir, args.cp_num, n_classes)
    else:
        model = background_resnet(num_classes=n_classes)
        optimizer = create_optimizer(model)

    # define objective function and LR scheduler
    objective = Prototypical() if args.loss_type == 'prototypical' else SoftmaxLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3,
                                                     min_lr=1e-5, threshold=1e-4,
                                                     verbose=True)

    if args.use_cuda:
        model.cuda()

    train_generator = metaGenerator(train_DB, read_MFB,
                                    nb_classes=args.nb_class_train,
                                    nb_samples_per_class=args.n_shot + args.n_query,
                                    max_iter=n_episode * (n_epochs - args.cp_num),
                                    xp=np)
    # training
    train(train_generator, model, objective, optimizer, n_episode, log_dir, scheduler)