Example 1
def load_checkpoint(context):
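    """Restore model weights, best metric, and step from a checkpoint into context."""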
    args = context['args']
    model = context['model']
    if args.resume == '':
        return
    t_saver = Saver(model_dir=args.resume)
    print('==> loading checkpoint from {}'.format(args.resume))
    if args.evaluate:
        checkpoint = t_saver.load_best()
    else:
        checkpoint = t_saver.load_latest()
    if checkpoint:
        best_metric = checkpoint['best_metric']
        context['best_metric'] = best_metric
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'step' in checkpoint:
            step = checkpoint['step']
        else:
            step = 0
        if args.step != -1:
            step = args.step
        print("==> loaded checkpoint {} (step {}, best_metric {})".format(args.resume,
                                                                          step, best_metric))

        context['step'] = step
    else:
        raise RuntimeError("==> no checkpoint at: {}".format(args.resume))
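
The function mutates the shared context dict in place. A hypothetical call site, as a minimal sketch (args would come from argparse and model would be any nn.Module; neither is shown in the source):

context = {'args': args, 'model': model, 'best_metric': 0.0, 'step': 0}
load_checkpoint(context)
# context['step'] and context['best_metric'] are updated if a checkpoint was found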
Example 2
import numpy as np
import torch


def main():
    global step, best_prec1, svm_layer, rgbCNN, flowCNN, crit, optimizer, saver, writer
    global train_dataset, test_dataset, train_loader, test_loader, vocab

    print('prepare dataset...')
    (train_dataset, train_loader), (test_dataset, test_loader) = prepare_dataset()

    # prepare model

    svm_layer, rgbCNN, flowCNN, crit, optimizer = prepare_model()
    if args.resume != '':
        t_saver = Saver(model_dir=args.resume, max_to_keep=5)
        print("=> loading checkpoint '{}'".format(args.resume))
        if args.evaluate:
            checkpoint = t_saver.load_best()
        else:
            checkpoint = t_saver.load_latest()
        if checkpoint is not None:
            best_prec1 = checkpoint['best_prec1']
            svm_layer.load_state_dict(checkpoint['model_state_dict'])
            if 'step' in checkpoint:
                step = checkpoint['step']
            else:
                step = 0
            if args.step != -1:
                step = args.step

            print("=> loaded checkpoint '{}' (step {})".format(
                args.resume, checkpoint['step']))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))

    if args.evaluate:
        validate(True)
        return

    # logger setup is disabled in this example
    # writer, saver = prepare_logger()
    print('start training...')
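    # train() only runs forward passes here to extract features, so gradients are disabled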
    with torch.no_grad():
        feature_Matrix, labels = train()

    np.save('train_feature_list_hmdb.npy', np.array(feature_Matrix))
    np.save('train_labels_list_hmdb.npy', np.array(labels))
    print('done!')
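
Both examples lean on an external Saver helper that is not shown. A minimal sketch of the interface they assume (the file names, .pth layout, and map_location choice are assumptions, not confirmed by the source):

import os
import torch

class Saver:
    """Loads torch-serialized checkpoint dicts from model_dir.

    Assumed layout: 'latest.pth' is rewritten on every save and
    'best.pth' whenever the tracked metric improves.
    """

    def __init__(self, model_dir, max_to_keep=5):
        self.model_dir = model_dir
        self.max_to_keep = max_to_keep  # retention policy for saved checkpoints

    def _load(self, name):
        path = os.path.join(self.model_dir, name)
        if not os.path.isfile(path):
            return None  # both examples check for a missing checkpoint
        return torch.load(path, map_location='cpu')

    def load_best(self):
        return self._load('best.pth')

    def load_latest(self):
        return self._load('latest.pth')

Returning None on a missing file matches both call sites: Example 1 tests the result with a truthiness check and Example 2 with an explicit `is not None`.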