Example #1
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader

# HagglingDataset, Metrics, get_model, and FLAGS are assumed to come from
# the same project; this snippet is its evaluation entry point.
def main(argv):
    test_dataset = HagglingDataset(FLAGS.test, FLAGS)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=10)

    ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model)

    model = get_model()
    model.load_model(ckpt, FLAGS.test_ckpt)
    model.eval()

    metrics = Metrics(FLAGS)

    # accumulate per-run metric rows in a list (DataFrame.append was removed
    # in pandas 2.0) and build the DataFrame once at the end
    rows = []

    with torch.no_grad():
        for i_batch, batch in enumerate(test_dataloader):

            # stochastic (VAE) models are sampled several times per batch;
            # deterministic models produce the same output every pass, so a
            # single run suffices
            batch_runs = FLAGS.batch_runs if FLAGS.VAE else 1

            for test_num in range(batch_runs):
                predictions, targets = model(batch)

                out = metrics.compute_and_save(predictions, targets, batch,
                                               i_batch, test_num)

                print(out)

                rows.append(out)

        # aggregate over all batches and runs, then persist the mean/std
        df = pd.DataFrame(rows)
        df_mean = df.mean(axis=0)
        df_std = df.std(axis=0)
        print(df_mean)
        print(df_std)
        out_dir = os.path.join('testResults', FLAGS.model)
        os.makedirs(out_dir, exist_ok=True)
        df_mean.to_csv(os.path.join(out_dir, 'mean.csv'))
        df_std.to_csv(os.path.join(out_dir, 'std.csv'))
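Both entry points rely on abseil-py flags (the FLAGS.flag_values_dict() call in Example #2 is the absl API). A minimal sketch of how such a script is typically wired up; the flag names match those used above, but the defaults are placeholders, not values from the original project:

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('test', 'data/test', 'Path to the test split (placeholder).')
flags.DEFINE_string('ckpt_dir', 'ckpts', 'Checkpoint directory (placeholder).')
flags.DEFINE_string('model', 'lstm', 'Model name (placeholder).')
flags.DEFINE_string('test_ckpt', None, 'Checkpoint to evaluate (placeholder).')
flags.DEFINE_integer('batch_size', 32, 'Batch size (placeholder).')
flags.DEFINE_integer('batch_runs', 1, 'Sampling runs per batch for stochastic models.')
flags.DEFINE_boolean('VAE', False, 'Whether the model is a VAE.')

if __name__ == '__main__':
    app.run(main)  # absl parses the flags, then calls main(argv)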
Example #2
import os

import torch
import wandb
from torch.utils.data import DataLoader

# HagglingDataset, Metrics, get_model, get_loss_fn, get_optimizer, decay_p,
# meanJointPoseError, and FLAGS are assumed to come from the same project.
def main(argv):
    # keep the decoder hidden size and depth in sync with the encoder
    FLAGS.dec_hidden_units = FLAGS.enc_hidden_units
    FLAGS.dec_layers = FLAGS.enc_layers

    # initialize the dataset and the data loader
    train_dataset = HagglingDataset(FLAGS.train, FLAGS)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=FLAGS.batch_size,
                                  shuffle=True,
                                  num_workers=10)
    test_dataset = HagglingDataset(FLAGS.test, FLAGS)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 shuffle=False,
                                 num_workers=10)

    # set the wandb config
    config = FLAGS.flag_values_dict()
    run = wandb.init(project="Sell-It", config=config)

    # initialize the model, log it for visualization
    model = get_model()
    # try:
    #     torch.onnx.export(model, next(iter(train_dataloader)),
    #                       os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/model.onnx'))
    #     wandb.save(os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/model.onnx'))
    # except Exception as e:
    #     print(e)

    starting_epoch = 0

    # restore the model and resume from the saved epoch if requested
    if FLAGS.resume_train:
        ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model + '/')
        starting_epoch = model.load_model(ckpt, None)

    # get the loss function and optimizers
    criterion = get_loss_fn()
    optimizer = get_optimizer(model.get_trainable_parameters())
    # initial value of the probability p that decay_p anneals each epoch
    p = 1.0

    metrics = Metrics(FLAGS)

    # run the training script
    for epoch in range(starting_epoch + 1, FLAGS.epochs + 1):

        print(f'Epoch {epoch}')

        # initialize the total epoch loss values
        train_loss_logs = {
            'Train/Total_Loss': 0,
            'Train/Reconstruction_Loss': 0,
            'Train/Regularization_Loss': 0,
            'Train/CrossEntropy_Loss': 0,
            'Train/VelocityRegularization': 0
        }

        train_metric_logs = {
            'Train/RightMSE': 0,
            'Train/LeftMSE': 0,
            'Train/RightNPSS': 0,
            'Train/LeftNPSS': 0,
            'Train/RightFrechet': 0,
            'Train/LeftFrechet': 0,
            'Train/RightSpeech': 0,
            'Train/LeftSpeech': 0,
            'Train/MSE': 0,
            'Train/NPSS': 0,
            'Train/Frechet': 0,
            'Train/Speech': 0
        }

        test_metric_logs = {
            'Test/RightMSE': 0,
            'Test/LeftMSE': 0,
            'Test/RightNPSS': 0,
            'Test/LeftNPSS': 0,
            'Test/RightFrechet': 0,
            'Test/LeftFrechet': 0,
            'Test/RightSpeech': 0,
            'Test/LeftSpeech': 0,
            'Test/MSE': 0,
            'Test/NPSS': 0,
            'Test/Frechet': 0,
            'Test/Speech': 0
        }

        # set model to train mode
        model.train()

        # anneal the probability p for this epoch (decay_p is assumed to
        # update the model's sampling schedule in place; its return value
        # was discarded in the original code, so p itself is never reassigned)
        decay_p(p, epoch, model)

        # run through all the batches
        for i_batch, batch in enumerate(train_dataloader):
            # zero prev gradients
            optimizer.zero_grad()

            # forward pass through the net
            predictions, targets = model(batch)

            # calculate loss
            losses = criterion(predictions, targets, model.parameters(), FLAGS)
            total_loss = losses['Total_Loss']

            # calculate gradients
            total_loss.backward()
            optimizer.step()

            # compute train metrics
            with torch.no_grad():
                if not FLAGS.skip_train_metrics:
                    train_metrics = metrics.compute_and_save(
                        predictions, targets, batch, i_batch, None)
                    train_metric_logs = {
                        'Train/' + key:
                        train_metrics[key] + train_metric_logs['Train/' + key]
                        for key in train_metrics
                    }

                train_loss_logs = {
                    'Train/' + key:
                    losses[key].item() + train_loss_logs['Train/' + key]
                    for key in losses
                }

        # set the model to evaluation mode
        model.eval()

        # calculate validation loss
        with torch.no_grad():
            for i_batch, batch in enumerate(test_dataloader):
                # forward pass through the net
                predictions, targets = model(batch)

                if FLAGS.model in ('bodyAE', 'bmg'):
                    test_metric_logs['Test/MSE'] += meanJointPoseError(
                        predictions, targets)
                else:
                    # consolidate metrics
                    test_metrics = metrics.compute_and_save(
                        predictions, targets, batch, i_batch, None)
                    test_metric_logs = {
                        'Test/' + key:
                        test_metrics[key] + test_metric_logs['Test/' + key]
                        for key in test_metrics
                    }

        # scale the metrics
        train_metric_logs = {
            key: train_metric_logs[key] / len(train_dataloader)
            for key in train_metric_logs
        }
        train_loss_logs = {
            key: train_loss_logs[key] / len(train_dataloader)
            for key in train_loss_logs
        }
        test_metric_logs = {
            key: test_metric_logs[key] / len(test_dataloader)
            for key in test_metric_logs
        }

        # log all the metrics
        run.log({**train_metric_logs, **train_loss_logs, **test_metric_logs})

        # periodically save a checkpoint under the current wandb run name
        if epoch % FLAGS.ckpt == 0 and epoch > 0:
            ckpt = os.path.join(FLAGS.ckpt_dir,
                                FLAGS.model + '/' + wandb.run.name + '/')
            model.save_model(ckpt, epoch)

    run.finish()
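The epoch bookkeeping in Example #2 follows a sum-then-average pattern: each batch's metric dict is added into a running total, which is divided by the number of batches once the loop ends. A hypothetical helper (average_metrics is not part of the original project) showing the same pattern in one place:

def average_metrics(metric_dicts):
    # Average a list of per-batch metric dicts key by key.
    totals = {}
    for d in metric_dicts:
        for key, value in d.items():
            totals[key] = totals.get(key, 0.0) + float(value)
    n = max(len(metric_dicts), 1)  # guard against an empty dataloader
    return {key: total / n for key, total in totals.items()}

Collecting each batch's metrics dict in a list and calling average_metrics at the end of the epoch would replace the per-key accumulation and the three scaling comprehensions above.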