        dropout_p=args.dropout_p)
model.to(device)

################
# RUN TRAINING #
################

trainer = autoregressive_train.AutoregressiveTrainer(
    model=model,
    data_loader=loader,
    params=trainer_params,
    snapshot_path=working_dir + '/snapshots',
    snapshot_name=args.run_name,
    snapshot_interval=args.num_iterations // 10,
    snapshot_exec_template=sbatch_executable,
    device=device,
    # logger=model_logging.Logger(validation_interval=None),
    logger=model_logging.TensorboardLogger(
        log_interval=500,
        validation_interval=1000,
        generate_interval=5000,
        log_dir=working_dir + '/logs/' + args.run_name,
        print_output=True,
    ))
if args.restore is not None:
    # `checkpoint` is loaded earlier in the script when --restore is given
    trainer.load_state(checkpoint)

print()
print("Model:", model.__class__.__name__)
print("Hyperparameters:", json.dumps(model.hyperparams, indent=4))
print("Trainer:", trainer.__class__.__name__)
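# Not shown in this excerpt: actually starting the run. A minimal sketch of
# the kickoff, assuming the trainer exposes a train() method driven by the
# same args.num_iterations used for the snapshot interval above (the method
# name and signature are assumptions, not confirmed by this excerpt):
trainer.train(steps=args.num_iterations)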
    reader.load_autoregressive_fr(model)
else:
    # restore the full model from a training snapshot
    checkpoint = torch.load(os.path.join('../snapshots', args.restore),
                            map_location='cpu')
    dims = checkpoint['model_dims']
    hyperparams = checkpoint['model_hyperparams']
    model = autoregressive_model.AutoregressiveFR(
        dims=dims, hyperparams=hyperparams, dropout_p=args.dropout_p)
    model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print("Num parameters:", model.parameter_count())

trainer = autoregressive_train.AutoregressiveTrainer(
    model=model,
    data_loader=None,
    device=device,
)
output = trainer.test(loader, model_eval=False, num_samples=args.num_samples,
                      return_logits=args.save_logits, return_ce=args.save_ce)
if args.save_logits or args.save_ce:
    # per-sequence logits/cross-entropies are written as individual .npy files
    output, logits = output
    logits_path = os.path.splitext(args.output)[0]
    os.makedirs(logits_path, exist_ok=True)
    for key, value in logits.items():
        np.save(f"{logits_path}/{key}.npy", value)
output = pd.DataFrame(output, columns=output.keys())
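# Not shown in this excerpt: persisting the results table. A minimal sketch
# using the standard pandas API, assuming args.output is the CSV path implied
# by the os.path.splitext(args.output) call above:
output.to_csv(args.output, index=False)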