import os

import pandas as pd
import torch
import wandb
from torch.utils.data import DataLoader

# HagglingDataset, Metrics, get_model, get_loss_fn, get_optimizer, decay_p,
# meanJointPoseError and FLAGS are project-local; their import paths are not
# shown in this listing.


def main(arg):
    # build the test set and its loader
    test_dataset = HagglingDataset(FLAGS.test, FLAGS)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=10)

    # restore the model from the requested checkpoint and freeze it for evaluation
    ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model)
    model = get_model()
    model.load_model(ckpt, FLAGS.test_ckpt)
    model.eval()

    metrics = Metrics(FLAGS)
    rows = []

    with torch.no_grad():
        # each batch is evaluated batch_runs times so stochastic (VAE) models are sampled repeatedly
        batch_runs = FLAGS.batch_runs

        for i_batch, batch in enumerate(test_dataloader):
            for test_num in range(batch_runs):
                predictions, targets = model(batch)
                out = metrics.compute_and_save(predictions, targets, batch, i_batch, test_num)
                print(out)
                rows.append(out)

    # aggregate the per-run metrics (rows are collected in a list because
    # DataFrame.append was removed in recent pandas)
    df = pd.DataFrame(rows)
    df_mean = df.mean(axis=0)
    df_std = df.std(axis=0)
    print(df_mean)
    print(df_std)

    # write the summary statistics, making sure the output directory exists
    out_dir = os.path.join('testResults', FLAGS.model)
    os.makedirs(out_dir, exist_ok=True)
    df_mean.to_csv(os.path.join(out_dir, 'mean.csv'))
    df_std.to_csv(os.path.join(out_dir, 'std.csv'))
def main(args):
    # make sure the decoder hidden units and layers match the encoder
    FLAGS.dec_hidden_units = FLAGS.enc_hidden_units
    FLAGS.dec_layers = FLAGS.enc_layers

    # initialize the datasets and the data loaders
    train_dataset = HagglingDataset(FLAGS.train, FLAGS)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, shuffle=True, num_workers=10)
    test_dataset = HagglingDataset(FLAGS.test, FLAGS)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, shuffle=False, num_workers=10)

    # set the wandb config
    config = FLAGS.flag_values_dict()
    run = wandb.init(project="Sell-It", config=config)

    # initialize the model, log it for visualization
    model = get_model()

    # try:
    #     torch.onnx.export(model, next(iter(train_dataloader)),
    #                       os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/model.onnx'))
    #     wandb.save(os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/model.onnx'))
    # except Exception as e:
    #     print(e)

    starting_epoch = 0

    # restore the model if training is resumed
    if FLAGS.resume_train:
        ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/')
        starting_epoch = model.load_model(ckpt, None)
        # starting_epoch = 130

    # get the loss function and the optimizer
    criterion = get_loss_fn()
    optimizer = get_optimizer(model.get_trainable_parameters())

    p = 1.0
    metrics = Metrics(FLAGS)

    # run the training loop
    for epoch in range(starting_epoch + 1, FLAGS.epochs + 1):
        print(epoch)

        # initialize the per-epoch loss and metric accumulators
        train_loss_logs = {
            'Train/Total_Loss': 0,
            'Train/Reconstruction_Loss': 0,
            'Train/Regularization_Loss': 0,
            'Train/CrossEntropy_Loss': 0,
            'Train/VelocityRegularization': 0
        }
        train_metric_logs = {
            'Train/RightMSE': 0, 'Train/LeftMSE': 0,
            'Train/RightNPSS': 0, 'Train/LeftNPSS': 0,
            'Train/RightFrechet': 0, 'Train/LeftFrechet': 0,
            'Train/RightSpeech': 0, 'Train/LeftSpeech': 0,
            'Train/MSE': 0, 'Train/NPSS': 0,
            'Train/Frechet': 0, 'Train/Speech': 0
        }
        test_metric_logs = {
            'Test/RightMSE': 0, 'Test/LeftMSE': 0,
            'Test/RightNPSS': 0, 'Test/LeftNPSS': 0,
            'Test/RightFrechet': 0, 'Test/LeftFrechet': 0,
            'Test/RightSpeech': 0, 'Test/LeftSpeech': 0,
            'Test/MSE': 0, 'Test/NPSS': 0,
            'Test/Frechet': 0, 'Test/Speech': 0
        }

        # set the model to train mode
        model.train()

        # set the decay factor for this epoch
        decay_p(p, epoch, model)

        # run through all the training batches
        for i_batch, batch in enumerate(train_dataloader):
            # zero previous gradients
            optimizer.zero_grad()

            # forward pass through the net
            predictions, targets = model(batch)

            # calculate the loss
            losses = criterion(predictions, targets, model.parameters(), FLAGS)
            total_loss = losses['Total_Loss']

            # backpropagate and update the weights
            total_loss.backward()
            optimizer.step()

            # accumulate train metrics and losses
            with torch.no_grad():
                if not FLAGS.skip_train_metrics:
                    train_metrics = metrics.compute_and_save(predictions, targets, batch, i_batch, None)
                    train_metric_logs = {
                        'Train/' + key: train_metrics[key] + train_metric_logs['Train/' + key]
                        for key in train_metrics
                    }
                train_loss_logs = {
                    'Train/' + key: losses[key].detach().cpu().numpy().item() + train_loss_logs['Train/' + key]
                    for key in losses
                }

        # set the model to evaluation mode
        model.eval()

        # calculate validation metrics
        with torch.no_grad():
            for i_batch, batch in enumerate(test_dataloader):
                # forward pass through the net
                predictions, targets = model(batch)

                if FLAGS.model == 'bodyAE' or FLAGS.model == 'bmg':
                    test_metric_logs['Test/MSE'] += meanJointPoseError(predictions, targets)
                else:
                    # consolidate metrics
                    test_metrics = metrics.compute_and_save(predictions, targets, batch, i_batch, None)
                    test_metric_logs = {
                        'Test/' + key: test_metrics[key] + test_metric_logs['Test/' + key]
                        for key in test_metrics
                    }

        # scale the metrics by the number of batches
        train_metric_logs = {key: train_metric_logs[key] / len(train_dataloader) for key in train_metric_logs}
        train_loss_logs = {key: train_loss_logs[key] / len(train_dataloader) for key in train_loss_logs}
        test_metric_logs = {key: test_metric_logs[key] / len(test_dataloader) for key in test_metric_logs}

        # log all the metrics
        run.log({**train_metric_logs, **train_loss_logs, **test_metric_logs})

        # save a checkpoint every FLAGS.ckpt epochs
        if epoch % FLAGS.ckpt == 0 and epoch > 0:
            ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/' + wandb.run.name + '/')
            model.save_model(ckpt, epoch)

    run.finish()
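# Neither function above is wired to the command line by itself. Since FLAGS
# exposes flag_values_dict() and main() takes the unused argv argument, the
# project most likely uses absl-py flags; a minimal entry point under that
# assumption is sketched below (the flag definitions themselves are expected
# to live elsewhere in the repo).
if __name__ == '__main__':
    from absl import app

    app.run(main)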