Esempio n. 1
0
        # Apply the gradients computed for this batch.
        cnn_optimizer.step()

        # Accumulate the scalar batch loss; it is divided by (i + 1) below to
        # display a running mean over the batches seen so far.
        xentropy_loss_avg += xentropy_loss.item()

        # Calculate running average of accuracy
        # `pred` is rebound from the model output to the argmax index along
        # dim 1, which is compared against the integer labels.
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total

        progress_bar.set_postfix(xentropy='%.3f' % (xentropy_loss_avg /
                                                    (i + 1)),
                                 acc='%.3f' % accuracy)

    # End of epoch: evaluate on the test set and measure model cost.
    test_acc = test(test_loader)
    flops = cnn.flops()
    # Cost score = FLOPs ratio + parameter-count ratio relative to the
    # MobileNet reference values (`params`, `mobilenet_flops` and
    # `mobilenet_params` are defined outside this view).
    score = flops / mobilenet_flops + params / mobilenet_params
    tqdm.write('test_acc: %.3f, flops: %s, parameters: %s, score %s' %
               (test_acc, flops, params, score))
    # NOTE(review): passing `epoch` to scheduler.step() is the pre-1.4
    # PyTorch form and is deprecated in later versions (use scheduler.step())
    # — confirm which PyTorch version this script targets.
    scheduler.step(epoch)

    # One CSV row per epoch; `accuracy` is the train accuracy accumulated
    # over the whole epoch.
    row = {
        'epoch': str(epoch),
        'train_acc': str(accuracy),
        'test_acc': str(test_acc)
    }
    csv_logger.writerow(row)

import os

# torch.save does not create missing directories, so ensure the checkpoint
# directory exists; otherwise the final save could fail after a full run.
os.makedirs('checkpoints', exist_ok=True)
torch.save(cnn.state_dict(), os.path.join('checkpoints', test_id + '.pt'))
csv_logger.close()
Esempio n. 2
0
        # Show the running mean loss over the (i + 1) batches seen so far and
        # the running train accuracy.
        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)

    # End of epoch: evaluate on the test set.
    test_acc = test(test_loader)
    tqdm.write('test_acc: %.3f' % (test_acc))

    # Advance the learning-rate schedule once per epoch.
    # scheduler.step(epoch)  # Use this line for PyTorch <1.4
    scheduler.step()     # Use this line for PyTorch >=1.4

    # One CSV row per epoch.
    row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)}
    csv_logger.writerow(row)

checkpoints_dir = 'checkpoints'
os.makedirs(checkpoints_dir, exist_ok=True)

try:
    # Save the trained weights as a plain state dict.
    torch.save(cnn.state_dict(),
               os.path.join(checkpoints_dir, '{}.pt'.format(test_id)))

    # ONNX export traces the model, so it needs a concrete example input:
    # a batch of 20 RGB images of size args.input_res x args.input_res.
    dummy_input = torch.randn((20, 3, args.input_res, args.input_res),
                              requires_grad=False)
    # Derive the path from checkpoints_dir rather than repeating the literal,
    # so the ONNX file always lands in the directory created above.
    onnx_path = os.path.join(checkpoints_dir, '{}.onnx'.format(test_id))

    # dummy_input lives on the CPU, so the model is moved there as well.
    # nn.Module.cpu() moves `cnn` in place, which affects any later code.
    torch.onnx.export(cnn.cpu(),
                      dummy_input,
                      onnx_path,
                      opset_version=10,
                      do_constant_folding=True,
                      keep_initializers_as_inputs=True)
finally:
    # Close the CSV log even if saving or ONNX export raises, so the
    # per-epoch rows are (presumably) flushed to disk.
    csv_logger.close()
Esempio n. 3
0
        # Accumulate the scalar batch loss; shown below as a mean over (i + 1)
        # batches.
        xentropy_loss_avg += xentropy_loss.item()

        # Calculate running average of accuracy
        # `pred` is rebound to the argmax index along dim 1 of the model output.
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total

        progress_bar.set_postfix(xentropy='%.3f' % (xentropy_loss_avg /
                                                    (i + 1)),
                                 acc='%.3f' % accuracy)

    # NOTE(review): `eval` here must be a project-defined function shadowing
    # the builtin (builtin eval does not accept a model and a loader);
    # consider renaming it to avoid confusion.
    test_acc = eval(cnn, test_loader)
    # Keep a checkpoint of the best test accuracy seen so far.
    if test_acc > max_acc:
        max_acc = test_acc
        torch.save(cnn.state_dict(), basic_path + 'checkpoints/max_acc.pt')

    # NOTE(review): scheduler.step(epoch) is the pre-1.4 PyTorch form;
    # newer versions expect scheduler.step() — confirm target version.
    scheduler.step(epoch)

    tqdm.write('test_acc: %.4f, max_acc: %.4f' % (test_acc, max_acc))
    # One CSV row per epoch, including the best accuracy so far.
    csv_logger.writerow({
        'epoch': str(epoch),
        'train_acc': str(accuracy),
        'test_acc': str(test_acc),
        'max_acc': str(max_acc)
    })

# Persist the final-epoch weights. Best-accuracy weights are saved separately
# as max_acc.pt during training whenever test accuracy improves.
last_ckpt_path = basic_path + 'checkpoints/last.pt'
torch.save(cnn.state_dict(), last_ckpt_path)
csv_logger.close()
Esempio n. 4
0
            # Total loss = sum of the individual loss terms in loss_avg.
            'loss': '%.5f' % (sum(loss_avg.values())),
        }
        # NOTE(review): loss_avg is rebound to formatted strings here; this is
        # only safe if it is re-initialised at the start of each epoch
        # (outside this view) — confirm.
        loss_avg = {k: '%.5f' % loss_avg[k] for k in loss_avg}
        row.update(loss_avg)
        row.update({
            'time':
            format_time(time.time() - st_time),
            # ETA = mean time per completed epoch x number of epochs remaining.
            'eta':
            format_time((time.time() - st_time) / (epoch + 1) *
                        (args.epochs - epoch - 1)),
        })
        print(row)
        logger.writerow(row)
    ##end for epoch
    # Checkpoint is a dict with the weights under 'model' (read back below).
    torch.save({
        'model': cnn.state_dict(),
    }, ckpt_cnn_filename)
    logger.close()
else:
    # Skip training and restore weights from the checkpoint written above.
    print("\n Loading pretrained model...")
    cnn.load_state_dict(torch.load(ckpt_cnn_filename)['model'])
    cnn = cnn.cuda()

# Evaluate the trained (or freshly loaded) model; report accuracy in percent.
val_acc = test(test_loader) * 100.0
print(val_acc)

# Append a visual separator to the run log.
log_separator = (
    "\n==================================================================================================="
)
with open(log_filename, 'a') as f:
    f.write(log_separator)

eval_results_fullpath = ckpt_directory + "/test_result_" + test_id + ".txt"
Esempio n. 5
0
        'test_acc': str(test_acc)
    }
    csv_logger.writerow(row)

    # New best test accuracy: record it and keep artifacts for this epoch.
    if test_acc > best_acc:

        best_acc = test_acc
        best_epoch = epoch
        # A model copy is only kept above 10% accuracy (presumably to skip
        # chance-level models — confirm), so best_model may stay unset.
        if test_acc > 0.1:
            best_model = copy.deepcopy(cnn)

        # keep csv of predictions
        # NOTE: preds_dir is only bound inside this branch.
        best_preds = df_preds
        preds_dir = 'preds/' + test_id
        os.makedirs(preds_dir, exist_ok=True)
        df_preds.to_csv(
            f'{preds_dir}/predictions_test_epoch{epoch}_acc{best_acc}.csv')

print("Best test acc:", str(best_acc))
print("Best Epoch:", str(best_epoch))
newdir = 'new_checkpoints/' + test_id + "/"
os.makedirs(newdir, exist_ok=True)
# Final-epoch weights.
torch.save(cnn.state_dict(), newdir + str(args.epochs) + 'epoch' + '.pt')

# best_model is only copied when a new best also exceeds 10% accuracy, so it
# can still be unset here. Assumes it is initialised to None before the
# training loop (not visible here) — confirm.
if best_model is not None:
    torch.save(best_model.state_dict(),
               newdir + 'best_model_at_epoch_' + str(best_epoch) + '.pt')
else:
    print("No model above 10% test accuracy; skipping best-model checkpoint.")

csv_logger.close()

# preds_dir is otherwise bound only inside the loop's "new best" branch, so
# recompute it here to avoid a NameError when accuracy never improved.
preds_dir = 'preds/' + test_id
# Assumes best_preds is initialised to None before the loop — confirm.
if best_preds is not None:
    os.makedirs(preds_dir, exist_ok=True)
    # Label the file with the epoch that produced these predictions
    # (best_epoch), not the last epoch of the loop.
    best_preds.to_csv(
        f'{preds_dir}/best_predictions_test_epoch{best_epoch}_acc{best_acc}.csv',
        index=False)
Esempio n. 6
0
        accuracy = correct / total

        progress_bar.set_postfix(xentropy='%.3f' % (xentropy_loss_avg /
                                                    (i + 1)),
                                 acc='%.3f' % accuracy)

    # End of epoch: evaluate on the test set.
    test_acc = test(test_loader)
    tqdm.write('test_acc: %.3f' % (test_acc))

    # NOTE(review): scheduler.step(epoch) is the pre-1.4 PyTorch form;
    # newer versions expect scheduler.step() — confirm target version.
    scheduler.step(epoch)

    row = {
        'epoch': str(epoch),
        'train_acc': str(accuracy),
        'test_acc': str(test_acc)
    }
    csv_logger.writerow(row)
    #Saving checkpoint
    # is_best is evaluated against the previous maximum before it is updated.
    is_best = test_acc > max_accuracy
    max_accuracy = max(test_acc, max_accuracy)
    # 'epoch' is stored as epoch + 1 (presumably the epoch to resume from —
    # confirm against save_checkpoint / the resume code).
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': cnn.state_dict(),
            'best_prec1': max_accuracy,
        }, is_best)

# Training finished: report the best test accuracy seen across all epochs.
print("Best Accuracy", max_accuracy)
# Weights were already persisted each epoch by save_checkpoint() in the loop,
# so no final torch.save is performed here.
csv_logger.close()
Esempio n. 7
0
    # End of epoch: evaluate on the test set.
    test_acc = test(test_loader)
    tqdm.write('test_acc: %.3f' % (test_acc))

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    row = {
        'epoch': str(epoch),
        'train_acc': str(accuracy),
        'test_acc': str(test_acc),
    }

    csv_logger.writerow(row)

    # Optional TensorBoard logging, indexed by 1-based epoch number.
    if args.tensorboard:
        global_step = epoch + 1
        # `i` is the index of the last batch of the inner loop, so
        # xentropy_loss_avg / (i + 1) is the mean batch loss for the epoch.
        writer.add_scalar('xentropy loss', xentropy_loss_avg / (i + 1),
                          global_step=global_step)
        writer.add_scalar('train acc', accuracy, global_step=global_step)
        writer.add_scalar('test acc', test_acc, global_step=global_step)

# Persist the final weights at the checkpoint path derived from args.out_dir,
# creating its parent directory first.
checkpoint_path = misc_util.get_checkpoint_path(args.out_dir)
checkpoint_parent = os.path.dirname(checkpoint_path)
misc_util.create_directory(checkpoint_parent)
torch.save(cnn.state_dict(), checkpoint_path)

# Release the CSV logger, plus the TensorBoard SummaryWriter when enabled.
csv_logger.close()
if args.tensorboard:
    writer.close()
Esempio n. 8
0
        # NOTE(review): presumably the branch taken when no validation is
        # run; 0 is a placeholder value — confirm.
        val_acc = 0

    # NOTE(review): scheduler.step(epoch) is the pre-1.4 PyTorch form;
    # newer versions expect scheduler.step() — confirm target version.
    scheduler.step(epoch)

    row = {
        'epoch': str(epoch),
        'train_acc': str(accuracy),
        'val_acc': str(val_acc)
    }
    csv_logger.writerow(row)

    # Per-epoch history used for the plots after training.
    acc_train[epoch] = accuracy
    acc_val[epoch] = val_acc
    # Stored un-normalised; it is divided by the batch count when plotted.
    loss_train[epoch] = xentropy_loss_avg

torch.save(cnn.state_dict(), model_weights_file)
csv_logger.close()
print("Training took {}".format(datetime.now() - start_time))

# Plot Accuracies/losses
# loss_train[epoch] holds xentropy_loss_avg, which appears to be a per-epoch
# *sum* of batch losses (conventionally reported as xentropy_loss_avg / (i + 1))
# — TODO confirm. `i` is presumably the 0-based index of the last training
# batch, so the batch count is i + 1. Dividing by `i` alone overstated the loss
# and produced inf/nan when an epoch had a single batch.
num_batches = i + 1
fig, ax_arr = plt.subplots(1, 2)
ax_arr[0].plot(np.arange(args.epochs),
               loss_train / num_batches,
               label='train',
               color='b')
ax_arr[0].set_xlabel("Epoch")
ax_arr[0].set_ylabel("Loss")

ax_arr[1].plot(np.arange(args.epochs), acc_train, label='train', color='b')
ax_arr[1].plot(np.arange(args.epochs), acc_val, label='validation', color='r')
ax_arr[1].set_xlabel("Epoch")