Example #1
def get_models(args,
               BERT_PT_PATH,
               trained=False,
               path_model_bert=None,
               path_model=None):
    # some constants
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']  # unclear why 'OP' is required; kept as-is

    print(f"Batch_size = {args.bS * args.accumulate_gradients}")
    print(f"Accumulate_gradients = {args.accumulate_gradients}")
    print(f"BERT parameters:")
    print(f"learning rate: {args.lr_bert}")
    print(f"Fine-tune BERT: {args.fine_tune}")

    # Get BERT
    model_bert, tokenizer, bert_config = get_bert(BERT_PT_PATH, args.bert_type,
                                                  args.do_lower_case,
                                                  args.my_pretrain_bert)
    args.iS = bert_config.hidden_size * args.num_target_layers  # Seq-to-SQL input vector dimension

    # Get Seq-to-SQL

    n_cond_ops = len(cond_ops)
    n_agg_ops = len(agg_ops)
    print(
        f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}"
    )
    print(f"Seq-to-SQL: the size of hidden dimension = {args.hS}")
    print(f"Seq-to-SQL: LSTM encoding layer size = {args.lS}")
    print(f"Seq-to-SQL: dropout rate = {args.dr}")
    print(f"Seq-to-SQL: learning rate = {args.lr}")
    model = Seq2SQL_v1(args.iS, args.hS, args.lS, args.dr, n_cond_ops,
                       n_agg_ops)
    model = model.to(device)  # assumes `device` is defined at module scope

    if trained:
        assert path_model_bert is not None
        assert path_model is not None

        if torch.cuda.is_available():
            res = torch.load(path_model_bert)
        else:
            res = torch.load(path_model_bert, map_location='cpu')
        model_bert.load_state_dict(res['model_bert'])
        model_bert.to(device)

        if torch.cuda.is_available():
            res = torch.load(path_model)
        else:
            res = torch.load(path_model, map_location='cpu')

        model.load_state_dict(res['model'])

    return model, model_bert, tokenizer, bert_config
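
For context, a sketch of how this factory could be called. Every value below (the flag values, the "./bert" path, and the checkpoint file names) is a placeholder, not taken from the original project:

import argparse

args = argparse.Namespace(
    bS=8, accumulate_gradients=4, lr_bert=1e-5, fine_tune=True,
    bert_type="uncased_L-12_H-768_A-12", do_lower_case=True,
    my_pretrain_bert=None, num_target_layers=2,
    hS=100, lS=2, dr=0.3, lr=1e-3,
)

# Fresh, untrained models:
model, model_bert, tokenizer, bert_config = get_models(args, "./bert")

# Restoring a trained checkpoint pair:
model, model_bert, tokenizer, bert_config = get_models(
    args, "./bert", trained=True,
    path_model_bert="model_bert_best.pt", path_model="model_best.pt")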
Example #2
    def __init__(self, encoder: Encoder, decoder: DecoderPythonCRF,
                 entries: EntriesProcessor, teacher_forcing_ratio=0.5,
                 learning_rate=0.01, max_input_length=40,
                 max_output_length=20, device=None):
        self.encoder = encoder
        self.decoder = decoder
        self.entries = entries
        self.teacher_forcing_ratio = teacher_forcing_ratio
        # pass learning_rate through; the original accepted it but left Adam's default
        self.encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
        self.decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
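
The device-selection idiom in this constructor recurs throughout these examples. Note that the API is torch.cuda.is_available(), a function on the torch.cuda module, which is easy to misspell as cuda_is_available. A minimal self-contained version:

import torch

# Prefer the GPU when the CUDA runtime is usable, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors and modules must be moved explicitly:
x = torch.randn(4, 3).to(device)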
Example #3
def get_vgg(cuda=False):
    vgg = models.vgg19(pretrained=True)
    vgg_features = vgg.features
    """
    We freeze the parameters because in style transfer, we only want to use the VGG19 model as a feature
    extractor, not to train it via backpropagation. Instead, backprop will aim to minimize the loss function
    that compares the content/style representations of the target image with the output image
    """
    for param in vgg.parameters():
        param.requires_grad = False
    if cuda:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        vgg_features.to(device)

    return vgg_features
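
A sketch of how the returned feature stack might be used for style transfer. The layer indices below are the conventional conv1_1 through conv5_1 positions in torchvision's VGG19, and the random image is a stand-in for a preprocessed input; neither is fixed by get_vgg itself:

import torch

vgg_features = get_vgg(cuda=False)
image = torch.randn(1, 3, 224, 224)  # placeholder for a preprocessed image

style_layers = {0, 5, 10, 19, 28}  # conv1_1 ... conv5_1 in torchvision's VGG19
activations = []
x = image
for idx, layer in enumerate(vgg_features):
    x = layer(x)
    if idx in style_layers:
        activations.append(x)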
Example #4
def get_models(config,
               BERT_PATH,
               trained=False,
               path_model_bert=None,
               path_model=None):
    # some constants
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']  # unclear why 'OP' is required; kept as-is

    # Get BERT
    model_bert, tokenizer, bert_config = get_bert(BERT_PATH)
    input_size = bert_config.hidden_size * config[
        "num_target_layers"]  # Seq-to-SQL input vector dimension

    # Get Seq-to-SQL

    number_cond_ops = len(cond_ops)
    number_agg_ops = len(agg_ops)
    model = Seq2SQL_v1(input_size=input_size,
                       hidden_size=100,
                       num_layer=2,
                       dropout=config["dropout"],
                       number_cond_ops=number_cond_ops,
                       number_agg_ops=number_agg_ops)
    model = model.to(device)  # assumes `device` is defined at module scope

    if trained:
        assert path_model_bert is not None
        assert path_model is not None

        if torch.cuda.is_available():
            res = torch.load(path_model_bert)
        else:
            res = torch.load(path_model_bert, map_location='cpu')
        model_bert.load_state_dict(res['model_bert'])
        model_bert.to(device)

        if torch.cuda.is_available():
            res = torch.load(path_model)
        else:
            res = torch.load(path_model, map_location='cpu')

        model.load_state_dict(res['model'])

    return model, model_bert, tokenizer, bert_config
Example #5
    def __init__(self,
                 actor_id,
                 replay_buffer,
                 parameter_server,
                 config,
                 epsilon,
                 eval=False):

        self.actor_id = actor_id
        self.replay_buffer = replay_buffer
        self.parameter_server = parameter_server
        self.config = config
        self.epsilon = epsilon
        self.eval = eval

        self.device = torch.device("cpu")
        if config["eval_device"] == "gpu":
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")

        self.observation_shape = config["observation_shape"]
        self.num_actions = config["num_actions"]
        self.multi_step_n = config.get("n_step", 1)
        self.q_update_freq = config.get("q_update_freq", 100)
        self.send_experience_freq = config.get("send_experience_freq", 100)

        self.q_net = get_q_network(config)
        if self.eval:
            self.q_net.eval()
        else:
            self.q_net.train()
        self.q_net.to(self.device)

        self.env = gym.make(config["env"])
        self.local_buffer = []

        self.continue_sampling = True
        self.cur_episodes = 0
        self.cur_steps = 0
Example #6
    def __init__(self, config, replay_buffer, parameter_server):
        self.config = config
        self.replay_buffer = replay_buffer
        self.parameter_server = parameter_server
        self.target_network_update_interval = config.get(
            "target_network_update_interval", 32)

        self.device = torch.device("cpu")
        if config["eval_device"] == "gpu":
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")

        self.q_net = get_q_network(config).train().to(self.device)
        self.target_q_net = copy.deepcopy(self.q_net).to(self.device)

        self.train_batch_size = config["train_batch_size"]
        self.total_collected_samples = 0
        self.samples_since_last_update = 0

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=config["lr"])
        self.loss = nn.MSELoss()

        self.send_weights()
        self.stopped = False
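
Examples #5 and #6 pull all of their hyperparameters from a plain config dict. A sketch of the keys these two classes actually read; the values shown are placeholders:

config = {
    "env": "CartPole-v1",            # gym environment id (placeholder)
    "observation_shape": (4,),
    "num_actions": 2,
    "eval_device": "cpu",            # "gpu" switches to cuda:0 when available
    "n_step": 1,                     # multi-step return length
    "q_update_freq": 100,            # actor pulls fresh weights every N steps
    "send_experience_freq": 100,     # actor ships its local buffer every N steps
    "target_network_update_interval": 32,
    "train_batch_size": 64,
    "lr": 1e-3,
}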
Example #7
def main(argv):
    try:
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter()
    except ImportError:
        writer = None

    torch.manual_seed(FLAGS.random_seed)

    np.random.seed(FLAGS.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(FLAGS.random_seed)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(FLAGS.device)

    kwargs = {
        "num_workers": 1,
        "pin_memory": True
    } if FLAGS.device == "cuda" else {}
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            root=".",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([
                # torchvision.transforms.
                #    RandomCrop(size=[28,28], padding=4)
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
        ),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            root=".",
            train=False,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
        ),
        batch_size=FLAGS.batch_size,
        **kwargs,
    )

    label = os.environ.get("SLURM_JOB_ID", str(uuid.uuid4()))
    if FLAGS.prefix:
        path = f"runs/mnist/{FLAGS.prefix}/{label}"
    else:
        path = f"runs/mnist/{label}"

    os.makedirs(path, exist_ok=True)
    os.chdir(path)
    FLAGS.append_flags_into_file("flags.txt")

    input_features = 28 * 28

    model = LIFConvNet(
        input_features,
        FLAGS.seq_length,
        model=FLAGS.model,
        device=device,
        only_first_spike=FLAGS.only_first_spike,
    ).to(device)

    if FLAGS.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=FLAGS.learning_rate)
    elif FLAGS.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS.learning_rate)

    if FLAGS.only_output:
        optimizer = torch.optim.Adam(model.out.parameters(),
                                     lr=FLAGS.learning_rate)

    training_losses = []
    mean_losses = []
    test_losses = []
    accuracies = []

    for epoch in range(FLAGS.epochs):
        training_loss, mean_loss = train(model,
                                         device,
                                         train_loader,
                                         optimizer,
                                         epoch,
                                         writer=writer)
        test_loss, accuracy = test(model,
                                   device,
                                   test_loader,
                                   epoch,
                                   writer=writer)

        training_losses += training_loss
        mean_losses.append(mean_loss)
        test_losses.append(test_loss)
        accuracies.append(accuracy)

        max_accuracy = np.max(np.array(accuracies))  # includes the current epoch

        if (epoch % FLAGS.model_save_interval == 0) and FLAGS.save_model:
            model_path = f"mnist-{epoch}.pt"
            save(
                model_path,
                model=model,
                optimizer=optimizer,
                epoch=epoch,
                is_best=accuracy >= max_accuracy,  # >= since the max already includes this epoch
            )

    np.save("training_losses.npy", np.array(training_losses))
    np.save("mean_losses.npy", np.array(mean_losses))
    np.save("test_losses.npy", np.array(test_losses))
    np.save("accuracies.npy", np.array(accuracies))
    model_path = "mnist-final.pt"
    save(
        model_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        is_best=accuracy >= max_accuracy,
    )
    if writer:
        writer.close()
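
The script calls a save helper that is not shown. A plausible minimal implementation matching the keyword arguments used above; the checkpoint layout and the mnist-best.pt name are assumptions:

import shutil
import torch

def save(path, model=None, optimizer=None, epoch=None, is_best=False):
    # Persist everything needed to resume training.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        path,
    )
    if is_best:
        shutil.copyfile(path, "mnist-best.pt")  # keep a copy of the best epoch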
Example #8
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import preprocessing

class M1(nn.Module):
    def __init__(self):
        super(M1, self).__init__()  # the class is M1, not Encode
        self.conv1 = nn.Conv1d(3,3,10)
        self.conv2 = nn.Conv1d(3,3,10)

    def forward(self,x):
        x = x.permute(0,2,1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.permute(0,2,1)
        return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# defining the model
model = M1()
# defining the optimizer
optimizer = optim.Adam(model.parameters())
# defining the loss function
criterion = nn.L1Loss()
# checking if GPU is available
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()
    
epochs = 10

for i in range(epochs):
    arr = torch.empty((0, 1, 3)).to(device)
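
The loop above is cut off right after allocating arr. As a sketch, one training step with this model could look like the following; the random tensors and their shapes are placeholders (two kernel-10 Conv1d layers shrink a length-28 sequence to length 10):

# Hypothetical training step; real code would iterate over a DataLoader.
x = torch.randn(8, 28, 3).to(device)        # (batch, sequence, channels)
target = torch.randn(8, 10, 3).to(device)   # 28 -> 19 -> 10 after the two convs

optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
optimizer.step()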
Example #9
def main():
    running_reward = 10
    torch.manual_seed(FLAGS.random_seed)
    random.seed(FLAGS.random_seed)

    label = f"{FLAGS.policy}-{FLAGS.model}-{FLAGS.random_seed}"
    os.makedirs(f"runs/cartpole/{label}", exist_ok=True)
    os.chdir(f"runs/cartpole/{label}")
    FLAGS.append_flags_into_file("flags.txt")

    np.random.seed(FLAGS.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(FLAGS.random_seed)

    env = gym.make(FLAGS.environment)
    env.seed(FLAGS.random_seed)  # seed before the first reset so it takes effect
    env.reset()

    if FLAGS.policy == "ann":
        policy = ANNPolicy()
    elif FLAGS.policy == "snn":
        policy = Policy()
    elif FLAGS.policy == "lsnn":
        policy = LSNNPolicy(device=FLAGS.device,
                            model=FLAGS.model).to(FLAGS.device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=FLAGS.learning_rate)

    running_rewards = []
    episode_rewards = []

    for e in range(FLAGS.episodes):
        state, ep_reward = env.reset(), 0

        for t in range(1, 10000):  # Don't infinite loop while learning
            action = select_action(state, policy, device=FLAGS.device)
            state, reward, done, _ = env.step(action)
            if FLAGS.render:
                env.render()
            policy.rewards.append(reward)
            ep_reward += reward
            if done:
                break

        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        finish_episode(policy, optimizer)

        if e % FLAGS.log_interval == 0:
            logging.info(
                "Episode {}/{} \tLast reward: {:.2f}\tAverage reward: {:.2f}".
                format(e, FLAGS.episodes, ep_reward, running_reward))
        episode_rewards.append(ep_reward)
        running_rewards.append(running_reward)
        if running_reward > env.spec.reward_threshold:
            logging.info("Solved! Running reward is now {} and "
                         "the last episode runs to {} time steps!".format(
                             running_reward, t))
            break

    np.save("running_rewards.npy", np.array(running_rewards))
    np.save("episode_rewards.npy", np.array(episode_rewards))
    torch.save(optimizer.state_dict(), "optimizer.pt")
    torch.save(policy.state_dict(), "policy.pt")
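
select_action and finish_episode are assumed helpers here. A sketch of finish_episode in the classic PyTorch REINFORCE pattern; it presumes the policy also collects saved_log_probs, which goes beyond what the snippet shows:

import torch

def finish_episode(policy, optimizer, gamma=0.99, eps=1e-8):
    # Compute discounted returns, walking the reward list backwards.
    returns, R = [], 0.0
    for r in reversed(policy.rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)

    # Policy-gradient loss: -log pi(a|s) * return.
    loss = torch.stack(
        [-log_prob * R for log_prob, R in zip(policy.saved_log_probs, returns)]
    ).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Clear the episode buffers for the next rollout.
    del policy.rewards[:]
    del policy.saved_log_probs[:]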
Example #10
def main(args):
    try:
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter()
    except ImportError:
        writer = None

    torch.manual_seed(FLAGS.random_seed)

    np.random.seed(FLAGS.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(FLAGS.random_seed)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(FLAGS.device)

    constant_current_encoder = IFConstantCurrentEncoder(
        seq_length=FLAGS.seq_length, v_th=FLAGS.current_encoder_v_th)

    def polar_current_encoder(x):
        x_p, _ = constant_current_encoder(2 * torch.nn.functional.relu(x))
        x_m, _ = constant_current_encoder(2 * torch.nn.functional.relu(-x))
        return torch.cat((x_p, x_m), 1)

    def current_encoder(x):
        x, _ = constant_current_encoder(2 * x)
        return x

    def poisson_encoder(x):
        return poisson_train(x, seq_length=FLAGS.seq_length)

    def signed_poisson_encoder(x):
        return signed_poisson_train(x, seq_length=FLAGS.seq_length)

    def signed_current_encoder(x):
        z, _ = constant_current_encoder(torch.abs(x))
        return torch.sign(x) * z

    num_channels = 4

    if FLAGS.encoding == "poisson":
        encoder = poisson_encoder
    elif FLAGS.encoding == "constant":
        encoder = current_encoder
    elif FLAGS.encoding == "signed_poisson":
        encoder = signed_poisson_encoder
    elif FLAGS.encoding == "signed_constant":
        encoder = signed_current_encoder
    elif FLAGS.encoding == "constant_polar":
        encoder = polar_current_encoder
        num_channels = 2 * num_channels

    luminance_transforms = [
        add_luminance,
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465, 0.4816),
                                         (0.2023, 0.1994, 0.2010, 0.20013)),
    ]

    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ] + luminance_transforms + [encoder])

    transform_test = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()] + luminance_transforms + [encoder])

    kwargs = {
        "num_workers": 0,
        "pin_memory": True
    } if FLAGS.device == "cuda" else {}
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(root=".",
                                     train=True,
                                     download=True,
                                     transform=transform_train),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(root=".",
                                     train=False,
                                     transform=transform_test),
        batch_size=FLAGS.batch_size,
        **kwargs,
    )

    label = os.environ.get("SLURM_JOB_ID", str(uuid.uuid4()))
    if not FLAGS.prefix:
        rundir = f"runs/cifar10/{label}"
    else:
        rundir = f"runs/cifar10/{FLAGS.prefix}/{label}"

    os.makedirs(rundir, exist_ok=True)
    os.chdir(rundir)
    FLAGS.append_flags_into_file("flags.txt")

    model = LIFConvNet(
        num_channels=num_channels,
        device=device,
    ).to(device)

    print(model)

    if device == "cuda":
        model = torch.nn.DataParallel(model).to(device)

    if FLAGS.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=FLAGS.learning_rate,
            momentum=0.9,
            weight_decay=5e-4 * FLAGS.batch_size,
            nesterov=True,
        )
    elif FLAGS.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS.learning_rate)
    elif FLAGS.optimizer == "rms":
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=FLAGS.learning_rate)

    if FLAGS.only_output:
        optimizer = torch.optim.Adam(model.out.parameters(),
                                     lr=FLAGS.learning_rate)

    if FLAGS.resume:
        if os.path.isfile(FLAGS.resume):
            checkpoint = torch.load(FLAGS.resume)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])

    if FLAGS.learning_rate_schedule:
        lr_scheduler = PiecewiseLinear(FLAGS.batch_size, [0, 5, FLAGS.epochs],
                                       [0, 0.4, 0])
    else:
        lr_scheduler = None

    training_losses = []
    mean_losses = []
    test_losses = []
    accuracies = []

    start = datetime.datetime.now()
    for epoch in range(FLAGS.start_epoch, FLAGS.start_epoch + FLAGS.epochs):
        training_loss, mean_loss = train(
            model,
            device,
            train_loader,
            optimizer,
            epoch,
            lr_scheduler=lr_scheduler,
            writer=writer,
        )
        test_loss, accuracy = test(model,
                                   device,
                                   test_loader,
                                   epoch,
                                   writer=writer)

        training_losses += training_loss
        mean_losses.append(mean_loss)
        test_losses.append(test_loss)
        accuracies.append(accuracy)

        if (epoch % FLAGS.model_save_interval == 0) and FLAGS.save_model:
            model_path = f"cifar10-{epoch}.pt"
            save(model_path, model, optimizer)

    stop = datetime.datetime.now()

    np.save("training_losses.npy", np.array(training_losses))
    np.save("mean_losses.npy", np.array(mean_losses))
    np.save("test_losses.npy", np.array(test_losses))
    np.save("accuracies.npy", np.array(accuracies))
    model_path = "cifar10-final.pt"
    save(model_path, model, optimizer)

    logging.info(f"output saved to {rundir}")
    logging.info(f"{start - stop}")
    if writer:
        writer.close()
Example #11
import argparse
import warnings

import torch
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

from genrl import DQN
from mario.supervised.adversarial import AdversariaTrainer
from mario.base.wrapper import MarioEnv

argument_parser = argparse.ArgumentParser(
    description="A script used to train agent adversarial on a dataset.")
argument_parser.add_argument("-i", "--input-path", type=str, required=True)
argument_parser.add_argument("-e", "--epochs", type=int, default=10)
argument_parser.add_argument("--lr", type=float, default=1e-3)
argument_parser.add_argument("-b", "--batch-size", type=int, default=64)
argument_parser.add_argument("-l", "--length", type=int, default=None)
argument_parser.add_argument("--enable-cuda", action="store_true")
args = argument_parser.parse_args()

if args.enable_cuda:
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        warnings.warn("cuda is ot available. Defaulting to cpu")
else:
    device = "cpu"

env = gym_super_mario_bros.make("SuperMarioBros-v0")
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = MarioEnv(env)
agent = DQN("cnn", env, replay_size=100000, epsilon_decay=100000)
trainer = AdversariaTrainer(
    agent=agent,
    env=env,
    dataset=args.input_path,
Example #12
    # ------------------------------------------

    parser_train = subparsers.add_parser(
        "train",
        help="Command to train the model",
    )
    parser_train.add_argument(
        "--epochs", help="Number of epochs", default=42, type=int)
    parser_train.add_argument(
        "--batch-size", help="Batch size", default=128, type=int)
    parser_train.set_defaults(func=train)

    # Evaluate parser
    # ------------------------------------------

    parser_evaluate = subparsers.add_parser(
        "evaluate",
        help="Command to evaluate an existing model"
    )
    parser_evaluate.set_defaults(func=evaluate)

    # Parse arguments
    # ------------------------------------------

    args = parser.parse_args()

    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    args.func(args)
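
This snippet begins mid-function: parser, subparsers, the --cuda flag, and the train/evaluate callbacks are all defined earlier. A sketch of the scaffolding it presupposes; the names are inferred from the snippet, not confirmed by it:

import argparse
import torch

def train(args):
    ...  # training entry point referenced by parser_train.set_defaults

def evaluate(args):
    ...  # evaluation entry point referenced by parser_evaluate.set_defaults

parser = argparse.ArgumentParser(description="Model training / evaluation CLI")
parser.add_argument("--cuda", action="store_true",
                    help="use CUDA when available (read as args.cuda below)")
subparsers = parser.add_subparsers()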
Example #13
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary

batch_size = 100
total_epoch = 50
learning_rate = 0.01
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
criterion = nn.CrossEntropyLoss()

print(use_cuda)

train_dataset = dsets.CIFAR10(root='./data',
                              train=True,
                              transform=transforms.ToTensor(),
                              download=True)
test_dataset = dsets.CIFAR10(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
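
The snippet stops after building the loaders: no network is defined, and torchsummary's summary is imported but never called. A sketch of how the remaining pieces could be wired together; the tiny CNN is purely illustrative:

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        x = F.relu(self.conv(x))
        return self.fc(x.flatten(1))

model = SmallNet().to(device)
summary(model, (3, 32, 32), device=device.type)  # torchsummary needs the input shape
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(total_epoch):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()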
Example #14
def main(args):
    #device configuration
    if args.cpu is not None:
        device = torch.device('cpu')
    elif args.gpu is not None:
        if not torch.cuda.is_available():
            print("GPU / cuda reported as unavailable to torch")
            exit(1)  # nonzero status: the requested device is unavailable
        device = torch.device('cuda')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create model directory
    if not os.path.exists(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    train_data = ld.get_data(labels_file=args.labels_file,
                             root_dir=args.train_image_dir,
                             mode="absolute")

    validation_data = ld.get_data(labels_file=args.labels_file,
                                  root_dir=args.validation_image_dir,
                                  mode="absolute")

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=validation_data,
                            batch_size=args.validation_batch_size)

    # Build the models
    if args.num_layers is not None and args.block_type is not None:
        if args.block_type == "bottleneck":
            net = model.ResNet(model.Bottleneck,
                               args.num_layers,
                               dropout=args.dropout)
        else:
            net = model.ResNet(model.BasicBlock,
                               args.num_layers,
                               dropout=args.dropout)
    else:
        if args.resnet_model == 152:
            net = model.ResNet152(args.dropout)
        elif args.resnet_model == 101:
            net = model.ResNet101(args.dropout)
        elif args.resnet_model == 50:
            net = model.ResNet50(args.dropout)
        elif args.resnet_model == 34:
            net = model.ResNet34(args.dropout)
        else:
            net = model.ResNet101(args.dropout)

    #load the model to the appropriate device
    net = net.to(device)
    params = net.parameters()

    # Loss and optimizer
    criterion = nn.MSELoss()  # a standard choice for regression targets

    if args.optim is not None:
        if args.optim == "adadelta":
            optimizer = torch.optim.Adadelta(params, lr=args.learning_rate)
        if args.optim == "adagrad":
            optimizer = torch.optim.Adagrad(params, lr=args.learning_rate)
        if args.optim == "adam":
            optimizer = torch.optim.Adam(params, lr=args.learning_rate)
        if args.optim == "adamw":
            optimizer = torch.optim.AdamW(params, lr=args.learning_rate)
        if args.optim == "rmsprop":
            optimizer = torch.optim.RMSProp(params, lr=args.learning_rate)
        if args.optim == "sgd":
            optimizer = torch.optim.SGD(params, lr=args.learning_rate)
    else:
        optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    val_acc_history = []
    train_acc_history = []
    failed_runs = 0
    prev_loss = float("inf")

    for epoch in range(args.num_epochs):
        running_loss = 0.0
        total_loss = 0.0

        for i, (inputs, labels) in enumerate(train_loader, 0):
            net.train()

            #adjust to output image coordinates
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs.float())
            loss = criterion(outputs.float(), labels.float())
            loss.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           args.clipping_value)
            optimizer.step()
            running_loss += loss.item()
            total_loss += loss.item()
            if i % 2 == 0:  # print every 2 mini-batches
                print('[%d, %5d] loss: %.5f' %
                      (epoch + 1, i + 1, running_loss / 2))
                running_loss = 0.0

        loss = 0.0

        #compute validation loss at the end of the epoch
        for i, (inputs, labels) in enumerate(val_loader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            net.eval()
            with torch.no_grad():
                outputs = net(inputs.float())
                loss += criterion(outputs, labels.float()).item()

        print("------------------------------------------------------------")
        print("Epoch %5d" % (epoch + 1))
        print("Training loss: {}, Avg Loss: {}".format(
            total_loss, total_loss / train_data.__len__()))
        print("Validation Loss: {}, Avg Loss: {}".format(
            loss, loss / validation_data.__len__()))
        print("------------------------------------------------------------")

        val_acc_history.append(loss)
        train_acc_history.append(total_loss)

        #save the model at the desired step
        if (epoch + 1) % args.save_step == 0:
            torch.save(net.state_dict(),
                       args.model_save_dir + "resnet" + str(epoch + 1) + ".pt")

        # early stopping: abort after 5 consecutive epochs without improvement
        if failed_runs > 5 and prev_loss < loss:
            break
        elif prev_loss < loss:
            failed_runs += 1
        else:
            failed_runs = 0

        prev_loss = loss

    #create a plot of the loss
    plt.title("Training vs Validation Accuracy")
    plt.xlabel("Training Epochs")
    plt.ylabel("Loss")
    plt.plot(range(1,
                   len(val_acc_history) + 1),
             val_acc_history,
             label="Validation loss")
    plt.plot(range(1,
                   len(train_acc_history) + 1),
             train_acc_history,
             label="Training loss")
    plt.xticks(np.arange(1, len(train_acc_history) + 1, 1.0))
    plt.legend()
    plt.ylim((0, max([max(val_acc_history), max(train_acc_history)])))

    if args.save_training_plot is not None:
        plt.savefig(args.save_training_plot + "loss_plot.png")

    plt.show()
    print('Finished Training')