def load_checkpoint(context):
    """Restore model weights and training progress from ``args.resume``.

    ``context`` is a dict that must contain ``'args'`` and ``'model'``. When
    ``args.resume`` is the empty string nothing is loaded and ``None`` is
    returned. Otherwise the best checkpoint (``args.evaluate`` set) or the
    latest one is loaded via ``Saver``. Its weights are applied to the model,
    and ``context['best_metric']`` and ``context['step']`` are filled in.

    Raises:
        RuntimeError: if the saver yields no (or a falsy) checkpoint.
    """
    opts = context['args']
    net = context['model']
    if opts.resume == '':
        return

    saver_for_resume = Saver(model_dir=opts.resume)
    print('==> loading checkpoint from {}'.format(opts.resume))

    # Evaluation restores the best-scoring checkpoint; training resumes from the newest.
    fetch = saver_for_resume.load_best if opts.evaluate else saver_for_resume.load_latest
    ckpt = fetch()

    if not ckpt:
        raise RuntimeError("==> no checkpoint at: {}".format(opts.resume))

    metric = ckpt['best_metric']
    context['best_metric'] = metric
    net.load_state_dict(ckpt['model_state_dict'])

    # A checkpoint without a 'step' entry starts from 0.
    resumed_step = ckpt['step'] if 'step' in ckpt else 0
    # Any args.step other than the -1 sentinel overrides the stored step.
    if opts.step != -1:
        resumed_step = opts.step

    print("==> loaded checkpoint {} (step {}, best_metric {})".format(opts.resume, resumed_step, metric))
    context['step'] = resumed_step
def main():
    """Build the data/model pipeline, optionally resume, then dump train features.

    Assigns the module-level globals declared below. When ``args.resume`` is
    non-empty, a checkpoint is restored into ``svm_layer``: the best one if
    ``args.evaluate`` is set, otherwise the latest. In evaluate mode this then
    calls ``validate(True)`` and returns. Otherwise it runs ``train()`` under
    ``torch.no_grad()`` and writes the returned features and labels to
    ``train_feature_list_hmdb.npy`` / ``train_labels_list_hmdb.npy``.

    Raises:
        RuntimeError: if ``args.resume`` is set but the saver returns ``None``.
    """
    global step, best_prec1, svm_layer, rgbCNN, flowCNN, crit, optimizer, saver, writer
    global train_dataset, test_dataset, train_loader, test_loader, vocab

    print('prepare dataset...')
    (train_dataset, train_loader), (test_dataset, test_loader) = prepare_dataset()

    svm_layer, rgbCNN, flowCNN, crit, optimizer = prepare_model()

    if args.resume != '':
        t_saver = Saver(model_dir=args.resume, max_to_keep=5)
        print("=> loading checkpoint '{}'".format(args.resume))
        # Evaluation restores the best checkpoint; otherwise resume from the latest.
        if args.evaluate:
            checkpoint = t_saver.load_best()
        else:
            checkpoint = t_saver.load_latest()
        if checkpoint is None:
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))

        best_prec1 = checkpoint['best_prec1']
        svm_layer.load_state_dict(checkpoint['model_state_dict'])
        # The checkpoint may lack a 'step' entry; default to 0 in that case.
        step = checkpoint['step'] if 'step' in checkpoint else 0
        # Any args.step other than the -1 sentinel overrides the stored step.
        if args.step != -1:
            step = args.step
        # Report the effective step. Indexing checkpoint['step'] here raised
        # KeyError for step-less checkpoints and ignored the args.step override.
        print("=> loaded checkpoint '{}' (step {})".format(args.resume, step))

    if args.evaluate:
        validate(True)
        return

    print('prepare logger...')
    # NOTE: logger/saver creation (prepare_logger) is currently disabled, so
    # the `writer` and `saver` globals are not assigned here.
    print('start training...')
    # Feature extraction only; no gradients are needed.
    with torch.no_grad():
        feature_matrix, labels = train()
    np.save('train_feature_list_hmdb.npy', np.array(feature_matrix))
    np.save('train_labels_list_hmdb.npy', np.array(labels))
    print('done!')