Esempio n. 1
0
def main():
    """Train an image classifier end to end.

    Parses CLI arguments, loads the data, builds the model, trains,
    validates, tests, and finally saves a checkpoint.
    """
    # Get input arguments
    args = arg_parser()

    # Process and load the data/images
    image_datasets, dataloaders = helper.process_and_load_data(args.data_dir)
    # BUG FIX: this message previously called .format(key) with no
    # placeholder and an undefined name `key`, raising NameError.
    print("The train, test & validation data has been loaded.")

    # Load the model
    model, optimizer, criterion = helper.build_model(args.arch, args.hidden_units, args.learning_rate)
    print("Model, optimizer & criterion have been loaded.")

    # Check if GPU is available
    device = helper.check_gpu(args.gpu)
    print('Using {} for computation.'.format(device))

    # Train and validate the model
    helper.train_and_validate_model(model, optimizer, criterion, dataloaders, device, args.epochs, print_every = 32)
    print("Training has been completed.")

    # Test the model
    helper.test_model(model, optimizer, criterion, dataloaders, device)
    print("Testing has been completed.")

    # Save the checkpoint
    helper.save_checkpoint(args.arch, model, args.epochs, args.hidden_units, args.learning_rate, image_datasets, args.save_dir)
    print("Model's checkpoint has been saved.")
Esempio n. 2
0
def run(args, model, device, train_loader, test_loader, scheduler, optimizer):
    """Full training loop.

    For each epoch: train, evaluate, step the LR scheduler, and save a
    checkpoint whenever top-1 accuracy improves (if args.detail is set).
    Updates the module-level `best_prec1`.
    """
    global best_prec1
    for epoch in range(args.start_epoch,
                       args.epochs):  # loop over the dataset multiple times
        train(args, model, device, train_loader, optimizer, epoch)
        prec1, prec5, loss = test(args, model, device, test_loader)

        # Step the scheduler; ReduceLROnPlateau needs the validation loss.
        if args.scheduler == "MultiStepLR":
            scheduler.step()
        elif args.scheduler == "ReduceLROnPlateau":
            scheduler.step(loss)
        else:
            pass

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if args.detail and is_best:
            save_checkpoint(
                {
                    # BUG FIX: was `args.epochs + 1`, which recorded the
                    # total configured epoch count instead of the epoch
                    # that actually produced the best model.
                    'epoch': epoch + 1,
                    'arch': args.model_type,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict()
                }, is_best, args.checkpoint_path + '_' + args.model_type +
                '_' + str(args.model_structure))
Esempio n. 3
0
 def train_epochs(self, train_corpus, dev_corpus, start_epoch, n_epochs):
     """Run up to n_epochs train/validate cycles.

     Checkpoints the model whenever dev loss improves and stops early
     after 5 consecutive epochs without improvement.
     """
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if self.stop:
             break
         epoch_label = str(epoch + 1)
         print('\nTRAINING : Epoch ' + epoch_label)
         self.train(train_corpus)
         # Training epoch done; evaluate on the dev set.
         print('\nVALIDATING : Epoch ' + epoch_label)
         dev_loss = self.validate(dev_corpus)
         self.dev_losses.append(dev_loss)
         print('validation loss = %.4f' % dev_loss)
         improved = self.best_dev_loss == -1 or self.best_dev_loss > dev_loss
         if improved:
             # Dev loss went down: persist the best model so far.
             self.best_dev_loss = dev_loss
             helper.save_checkpoint(
                 {
                     'epoch': (epoch + 1),
                     'state_dict': self.model.state_dict(),
                     'best_loss': self.best_dev_loss,
                     'optimizer': self.optimizer.state_dict(),
                 }, self.config.save_path + 'model_best.pth.tar')
             self.times_no_improvement = 0
         else:
             self.times_no_improvement += 1
             # Early stop after 5 stale epochs.
             if self.times_no_improvement == 5:
                 self.stop = True
         # Persist the train and dev loss plots every epoch.
         helper.save_plot(self.train_losses, self.config.save_path,
                          'training', epoch + 1)
         helper.save_plot(self.dev_losses, self.config.save_path, 'dev',
                          epoch + 1)
Esempio n. 4
0
 def __create_checkpoint(self, loss_item, epoch):
     """Write separate encoder and decoder checkpoints for this epoch."""
     # Each sub-module is saved under its own checkpoint path.
     for part, state_fn in (("encoder", self.model.get_encoder_state_dict),
                            ("decoder", self.model.get_decoder_state_dict)):
         part_path = get_checkpoint_path(self.path, part, epoch)
         save_checkpoint(loss_item, state_fn(), part_path)
def main():
    """Script entry point: load data, build the network, train it, and
    save a checkpoint.

    Relies on module-level globals (location, structure, hiddenlayer1,
    dropout, lr, power, epochs, path) presumably parsed from the CLI
    elsewhere in this file.
    """
    # Build datasets and loaders from the data directory.
    train_data, valid_data, test_data, train_loader, valid_loader, test_loader = helper.dataloaders(
        location)
    model, model.name, model.classifier, criterion, optimizer = helper.nn_class(
        structure, hiddenlayer1, dropout, lr, power)
    # The literal 5 is the print/validation interval passed to the trainer.
    helper.model_processing(model, train_loader, valid_loader, criterion,
                            optimizer, epochs, 5, power)
    # NOTE(review): `hiddenlayer` (not `hiddenlayer1`) is passed here —
    # likely a typo/NameError unless a global `hiddenlayer` exists; confirm.
    helper.save_checkpoint(model, train_data, path, hiddenlayer, dropout, lr)
    print("Training Done!")
Esempio n. 6
0
 def train_epochs(self, train_corpus, dev_corpus, test_corpus, start_epoch,
                  n_epochs):
     """Trains model for n_epochs epochs.

     Each epoch: (for SGD) decay the learning rate, train, then evaluate
     on dev and test sets.  Checkpoints on dev-accuracy improvement;
     otherwise shrinks the SGD lr (stopping below minlr) or, for adam,
     counts toward early stopping.
     """
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             print('\nTRAINING : Epoch ' + str((epoch + 1)))
             # Exponential lr decay applies only to SGD and only after
             # the first epoch of this run.
             self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] * self.config.lr_decay \
                 if epoch > start_epoch and 'sgd' in self.config.optimizer else self.optimizer.param_groups[0]['lr']
             if 'sgd' in self.config.optimizer:
                 print('Learning rate : {0}'.format(
                     self.optimizer.param_groups[0]['lr']))
             self.train(train_corpus)
             # training epoch completes, now do validation
             print('\nVALIDATING : Epoch ' + str((epoch + 1)))
             dev_acc = self.validate(dev_corpus)
             self.dev_accuracies.append(dev_acc)
             print('validation acc = %.2f%%' % dev_acc)
             test_acc = self.validate(test_corpus)
             # BUG FIX: this line previously printed 'validation acc' for
             # the test-set score, making the two log lines
             # indistinguishable.
             print('test acc = %.2f%%' % test_acc)
             # save model if dev accuracy goes up
             if self.best_dev_acc < dev_acc:
                 self.best_dev_acc = dev_acc
                 helper.save_checkpoint(
                     {
                         'epoch': (epoch + 1),
                         'state_dict': self.model.state_dict(),
                         'best_acc': self.best_dev_acc,
                         'optimizer': self.optimizer.state_dict(),
                     }, self.config.save_path + 'model_best.pth.tar')
                 self.times_no_improvement = 0
             else:
                 if 'sgd' in self.config.optimizer:
                     # Shrink the lr on a stale epoch; stop once it falls
                     # below the configured minimum.
                     self.optimizer.param_groups[0][
                         'lr'] = self.optimizer.param_groups[0][
                             'lr'] / self.config.lrshrink
                     print('Shrinking lr by : {0}. New lr = {1}'.format(
                         self.config.lrshrink,
                         self.optimizer.param_groups[0]['lr']))
                     if self.optimizer.param_groups[0][
                             'lr'] < self.config.minlr:
                         self.stop = True
                 if 'adam' in self.config.optimizer:
                     self.times_no_improvement += 1
                     # early stopping (at 'n'th decrease in accuracy)
                     if self.times_no_improvement == self.config.early_stop:
                         self.stop = True
             # save the train loss and development accuracy plot
             helper.save_plot(self.train_accuracies, self.config.save_path,
                              'training_acc_plot_', epoch + 1)
             helper.save_plot(self.dev_accuracies, self.config.save_path,
                              'dev_acc_plot_', epoch + 1)
         else:
             break
Esempio n. 7
0
def main():
    """CLI entry point: parse training hyper-parameters, optionally train
    an image classifier (when --st is given), and save a checkpoint."""
    # Create parser object
    parser = argparse.ArgumentParser(description="Neural Network Image Classifier Training")

    # Define argument for parser object
    parser.add_argument('--save_dir', type=str, help='Define save directory for checkpoints model as a string')
    parser.add_argument('--arch', dest='arch', action ='store', type = str, default = 'densenet', help='choose a tranfer learning model or architechture')
    parser.add_argument('--learning_rate', dest = 'learning_rate', action='store', type=float, default=0.001, help='Learning Rate for Optimizer')
    parser.add_argument('--hidden_units', dest = 'hidden_units', action='store', type=int, default=512, help='Define number of hidden unit')
    parser.add_argument('--epochs', dest = 'epochs', action='store', type=int, default=1, help='Number of Training Epochs')
    # BUG FIX: default was the string 'False', which is truthy, so
    # args.gpu evaluated as "on" even when --gpu was not passed.
    parser.add_argument('--gpu', dest = 'gpu', action='store_true', default = False, help='Use GPU if --gpu')
    parser.add_argument('--st', action = 'store_true', default = False, dest = 'start', help = '--st to start predicting')

    # Parse the argument from standard input
    args = parser.parse_args()

    # Print out the passing/default parameters
    print('-----Parameters------')
    print('gpu              = {!r}'.format(args.gpu))
    print('epoch(s)         = {!r}'.format(args.epochs))
    print('arch             = {!r}'.format(args.arch))
    print('learning_rate    = {!r}'.format(args.learning_rate))
    print('hidden_units     = {!r}'.format(args.hidden_units))
    print('start            = {!r}'.format(args.start))
    print('----------------------')

    if args.start:
        class_labels, trainloaders, testloaders, validloaders = helper.load_image()
        model = helper.load_pretrained_model(args.arch, args.hidden_units)
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.classifier.parameters(), lr = args.learning_rate)
        helper.train_model(model, args.learning_rate, criterion, trainloaders, validloaders, args.epochs, args.gpu)
        helper.test_model(model, testloaders, args.gpu)
        # Move back to CPU so the checkpoint is loadable without a GPU.
        model.to('cpu')

        # saving checkpoints
        helper.save_checkpoint({
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'hidden_units': args.hidden_units,
            'class_labels': class_labels
        })
        print('model checkpoint saved')
def main():
    """End-to-end training entry point: prepare the data, build the
    classifier, train/validate it, test it, and write a checkpoint."""
    cli = parse_args()

    # Make sure the checkpoint directory exists before training starts.
    Path(cli.save_dir).mkdir(parents=True, exist_ok=True)

    datasets, loaders = transformers(cli)

    lr = float(cli.learning_rate)
    n_epochs = int(cli.epochs)
    model, criterion, optimizer = classifier(cli.arch, float(cli.dropout),
                                             int(cli.hidden_units), lr,
                                             cli.gpu)

    # Record the class-to-index mapping on the model for later inference.
    model.class_to_idx = datasets['train'].class_to_idx

    train(model, criterion, optimizer, loaders["train"], loaders["validate"],
          n_epochs, cli.gpu)
    test_network(model, loaders["test"], criterion, cli.gpu)
    save_checkpoint(cli.save_dir, model, optimizer, cli.arch, n_epochs, lr)
def main():
    """Build a model from CLI arguments, train it, checkpoint it, and
    run it on the test set."""
    # Construct the model and the argument namespace from the command line.
    model, cli_args = arg_parser()
    args = vars(cli_args)

    data_dir = args['data_directory']
    loaders = prep_dataset(data_dir)
    dataset = get_datasets(data_dir)

    # Attach the class-to-index mapping as a model attribute.
    model.class_to_idx = dataset['train'].class_to_idx

    # Train, save the checkpoint, then evaluate on the test split.
    adam = optim.Adam(model.classifier.parameters(), lr=args['learning_rate'])
    train_model(model, loaders['train'], loaders['valid'],
                epochs=args['epochs'], optimizer=adam, device=True)
    save_checkpoint(model, loaders['test'], args['save_dir'])
    test_model(model, loaders['test'])
Esempio n. 10
0
    def train(epoch):
        """Run one discriminator training epoch over `data_loader`.

        Closure over the enclosing scope: sess, model, data_loader,
        train_discriminator_step_tf, train_outs_dict, input_dict_func,
        merged_summaries, summary_writer, global_step, saver, global_args.
        Logs progress, writes TF summaries, and checkpoints every 20
        epochs (with a backup copy every 60).
        """
        data_loader.train()
        model.train()
        # Running sums for per-epoch averages (only the discriminator
        # loss and batch size are actually updated in this loop).
        train_gen_loss_accum, train_dis_loss_accum, train_likelihood_accum, train_kl_accum, batch_size_accum = 0, 0, 0, 0, 0
        start = time.time();
        for batch_idx, curr_batch_size, batch in data_loader: 

            # One discriminator update; fetch its cost for logging.
            disc_train_step_np, discriminator_cost_np = \
                sess.run([train_discriminator_step_tf, train_outs_dict['discriminator_cost']], feed_dict = input_dict_func(batch, np.asarray([model.train_mode,])))

            train_dis_loss_accum += curr_batch_size*discriminator_cost_np
            batch_size_accum += curr_batch_size

            if batch_idx % global_args.log_interval == 0:
                end = time.time();
                print('Train: Epoch {} [{:7d} ()]\tDiscriminator Cost: {:.6f}\tTime: {:.3f}'.format(
                      epoch, batch_idx * curr_batch_size, discriminator_cost_np, (end - start)))

                # Append the raw cost to a plain-text trace file.
                with open(global_args.exp_dir+"training_traces.txt", "a") as text_file:
                    text_file.write(str(discriminator_cost_np) + '\n')
                start = time.time()
    
        # Summaries are computed on the last batch of the epoch.
        summary_str = sess.run(merged_summaries, feed_dict = input_dict_func(batch, np.asarray([0,])))
        summary_writer.add_summary(summary_str, (tf.train.global_step(sess, global_step)))
        
        checkpoint_time = 20
        if epoch % checkpoint_time == 0:
            print('====> Average Train: Epoch: {}\tDiscriminator Cost: {:.6f}'.format(epoch, train_dis_loss_accum/batch_size_accum))

            distributions.visualizeProductDistribution2(sess, input_dict_func(batch, np.asarray([model.train_mode,])), batch, model.obs_dist,
            save_dir=global_args.exp_dir+'Visualization/Train/', postfix='train')

            # Primary checkpoint every `checkpoint_time` epochs; a backup
            # copy in a second directory every 60 epochs.
            checkpoint_path1 = global_args.exp_dir+'checkpoint/'
            checkpoint_path2 = global_args.exp_dir+'checkpoint2/'
            print('====> Saving checkpoint. Epoch: ', epoch); start_tmp = time.time()
            helper.save_checkpoint(saver, sess, global_step, checkpoint_path1) 
            end_tmp = time.time(); print('Checkpoint path: '+checkpoint_path1+'   ====> It took: ', end_tmp - start_tmp)
            if epoch % 60 == 0: 
                print('====> Saving checkpoint backup. Epoch: ', epoch); start_tmp = time.time()
                helper.save_checkpoint(saver, sess, global_step, checkpoint_path2) 
                end_tmp = time.time(); print('Checkpoint path: '+checkpoint_path2+'   ====> It took: ', end_tmp - start_tmp)
Esempio n. 11
0
 def train_epochs(self, start_epoch, n_epochs):
     """Trains model for n_epochs epochs.

     Each epoch: (for SGD) decay the learning rate, train, validate, and
     checkpoint on dev-accuracy improvement.  On a stale epoch, SGD
     shrinks the lr (stopping below config.minlr); adam counts toward
     early stopping after 5 stale epochs.
     """
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             print('\nTRAINING : Epoch ' + str((epoch + 1)))
             # Exponential lr decay applies only to SGD, from the second
             # overall epoch onward.
             self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] * self.config.lr_decay \
                 if (epoch + 1) > 1 and 'sgd' in self.config.optimizer else self.optimizer.param_groups[0]['lr']
             print('Learning rate : {0}'.format(self.optimizer.param_groups[0]['lr']))
             self.train()
             # training epoch completes, now do validation
             print('\nVALIDATING : Epoch ' + str((epoch + 1)))
             dev_acc = self.validate()
             self.dev_accuracies.append(dev_acc)
             print('validation accuracy = %.2f' % dev_acc)
             # save model if dev accuracy goes up
             if self.best_dev_acc < dev_acc:
                 self.best_dev_acc = dev_acc
                 helper.save_checkpoint({
                     'epoch': (epoch + 1),
                     'state_dict': helper.get_state_dict(self.model, self.config),
                     'best_acc': self.best_dev_acc,
                     'optimizer': self.optimizer.state_dict(),
                 }, self.config.save_path + 'model_best.pth')
                 self.times_no_improvement = 0
             else:
                 if 'sgd' in self.config.optimizer:
                     # Shrink the lr on a stale epoch; stop once it falls
                     # below the configured minimum.
                     self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0][
                                                                'lr'] / self.config.lrshrink
                     print('Shrinking lr by : {0}. New lr = {1}'.format(self.config.lrshrink,
                                                                        self.optimizer.param_groups[0]['lr']))
                     if self.optimizer.param_groups[0]['lr'] < self.config.minlr:
                         self.stop = True
                 if 'adam' in self.config.optimizer:
                     self.times_no_improvement += 1
                     # early stopping (at the 5th decrease in accuracy)
                     if self.times_no_improvement >= 5:
                         self.stop = True
         else:
             break
Esempio n. 12
0
def main():
    """Create and train a classifier per CLI arguments, save a
    checkpoint, then run validation on the test split."""
    input_args = get_input_args()
    # Use the GPU only when requested AND actually available.
    gpu = torch.cuda.is_available() and input_args.gpu

    dataloaders, class_to_idx = helper.get_dataloders(input_args.data_dir)

    model, optimizer, criterion = helper.model_create(
        input_args.architectures,
        input_args.learning_rate,
        input_args.hidden_units,
        class_to_idx
        )

    if gpu:
        model.cuda()
        criterion.cuda()
    else:
        torch.set_num_threads(input_args.num_threads)

    # BUG FIX: epochs was hard-coded to 3, while the checkpoint below
    # recorded input_args.epochs — training now honors the CLI value so
    # the saved metadata matches what actually ran.
    epochs = input_args.epochs
    print_every = 40
    # NOTE(review): device is fixed to 'cpu' here even when the model was
    # moved to CUDA above — confirm helper.train handles this correctly.
    helper.train(model, dataloaders['training'], epochs, print_every, criterion, optimizer, device='cpu')

    # Build the checkpoint path, creating save_dir if necessary.
    if input_args.save_dir:
        if not os.path.exists(input_args.save_dir):
            os.makedirs(input_args.save_dir)

        file_path = input_args.save_dir + '/' + input_args.architectures + '_checkpoint.pth'
    else:
        file_path = input_args.architectures + '_checkpoint.pth'

    helper.save_checkpoint(file_path,
                            model, optimizer,
                            input_args.architectures,
                            input_args.learning_rate,
                            input_args.epochs
                            )

    helper.validation(model, dataloaders['testing'], criterion)
Esempio n. 13
0
                    dest="fc2",
                    action="store",
                    default=1024,
                    help="state the units for fc2")

# Parse CLI arguments and unpack them into module-level names.
pa = parser.parse_args()
data_path = pa.data_dir
filepath = pa.save_dir
learn_r = pa.learn_r
architecture = pa.architecture
dropout = pa.dropout
fc2 = pa.fc2
gpu_cpu = pa.gpu_cpu
epoch_num = pa.epoch_num

# load the data - data_load() from help.py
trainloader, validationloader, testloader = hp.load_data(data_path)

# build model
model, optimizer, criterion = hp.nn_architecture(architecture, dropout, fc2,
                                                 learn_r)

# train model
# The literal 20 is the print/validation interval passed to the trainer.
hp.train_network(model, criterion, optimizer, trainloader, validationloader,
                 epoch_num, 20, gpu_cpu)

# checkpoint the model
hp.save_checkpoint(filepath, architecture, dropout, learn_r, fc2, epoch_num)

print("model has been successfully trained")
Esempio n. 14
0
                    type=int,
                    dest="hidden_units",
                    action="store",
                    default=120,
                    help="state the units for fist hidden layer (deafult 120)")

# Parse CLI arguments and unpack them into module-level names.
pa = parser.parse_args()
data_path = pa.data_dir
path = pa.save_dir
lr = pa.learning_rate
structure = pa.arch
dropout = pa.dropout
hidden_layer1 = pa.hidden_units
power = pa.gpu
epochs = pa.epochs

#load the data - invoke the data_load method from helper
trainloader, v_loader, testloader = hp.data_load(data_path)

#create the model
model, optimizer, criterion = hp.nn_arch(structure, dropout, hidden_layer1, lr,
                                         power)

#train the neural network
# The literal 20 is the print/validation interval passed to the trainer.
hp.train_network(model, optimizer, criterion, epochs, 20, trainloader, power)

#save  the train network checkpoint
hp.save_checkpoint(path, structure, hidden_layer1, dropout, lr)

print("The Model is trained")
Esempio n. 15
0
# Parse CLI arguments (parser is defined earlier in this file) and echo
# the effective parameters before training.
results = parser.parse_args()
print('---------Parameters----------')
print('gpu              = {!r}'.format(results.gpu))
print('epoch(s)         = {!r}'.format(results.epochs))
print('arch             = {!r}'.format(results.arch))
print('learning_rate    = {!r}'.format(results.learning_rate))
print('hidden_units     = {!r}'.format(results.hidden_units))
print('start            = {!r}'.format(results.start))
print('-----------------------------')

# Only run training when --start was requested.
if results.start == True:
    class_labels, trainloader, testloader, validloader = helper.load_img()
    model = helper.load_pretrained_model(results.arch, results.hidden_units)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.classifier.parameters(),
                           lr=results.learning_rate)
    helper.train_model(model, results.learning_rate, criterion, trainloader,
                       validloader, results.epochs, results.gpu)
    helper.test_model(model, testloader, results.gpu)
    # Move back to CPU so the checkpoint loads without a GPU.
    model.to('cpu')

    # Save Checkpoint for predection
    helper.save_checkpoint({
        'arch': results.arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'hidden_units': results.hidden_units,
        'class_labels': class_labels
    })
    print('Checkpoint has been saved.')
Esempio n. 16
0
                dest="learning_rate",
                default=0.001)
ap.add_argument('--dropout', type=float, dest="dropout", default=0.5)

# Parse CLI arguments and unpack them into module-level names.
pa = ap.parse_args()

data_dir = pa.data_dir
path = pa.save_dir
structure = pa.arch
power = pa.gpu
epochs = pa.epochs
hidden_layer = pa.hidden_units
lr = pa.learning_rate
dropout = pa.dropout

# load the 3 dataloaders (takes data_dir as argument)
trainloader, validloader, testloader, class_to_idx = helper.load_data(data_dir)

# set up model structure (takes structure, dropout, hidden_layer, lr, power)
model, optimizer, criterion = helper.nn_setup(structure, dropout, hidden_layer,
                                              lr, power)

# train and validate the network
# The literal 20 is the print/validation interval passed to the trainer.
helper.train_network(model, optimizer, criterion, epochs, 20, trainloader,
                     validloader, power)

# Persist the trained model together with its class-to-index mapping.
helper.save_checkpoint(model, path, structure, hidden_layer, dropout, lr,
                       class_to_idx)

print("------------Model Trained!------------")
def main():
    """Depth-completion entry point.

    Three modes driven by the module-level `args`:
      * --evaluate <ckpt>: load a checkpoint and run validation only;
      * --resume <ckpt>:   restore model/optimizer state and continue
                           training from the saved epoch;
      * default:           train from scratch.
    Rebinds the global `args` from the checkpoint when loading, keeping
    only data_folder/val/result from the fresh command line.
    """
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        # Keep a handle on the fresh CLI args: the checkpoint's stored
        # args will replace `args`, but a few fields must come from the
        # current invocation.
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            # Resume from the epoch after the one stored in the checkpoint.
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    # Only optimize trainable parameters.
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        # Restore weights/optimizer BEFORE wrapping in DataParallel so the
        # state-dict keys match.
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        # Carry the previous best result forward so improvements are
        # judged against the resumed run.
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            # Unwrap DataParallel so the checkpoint loads on a bare model.
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)
Esempio n. 18
0
    def train(epoch):
        """Run one GAN training epoch over `data_loader`.

        Closure over the enclosing scope: sess, data_loader,
        train_discriminator_step_tf, train_generator_step_tf,
        train_outs_dict, input_dict_func, max_abs_discriminator_vars,
        merged_summaries, summary_writer, global_step, saver,
        generative_dict, inference_obs_dist, distributions, global_args.
        The discriminator is updated every batch; the generator every
        5th batch.  Logs progress, writes TF summaries, visualizes, and
        checkpoints periodically.
        """
        data_loader.train()
        # Running sums for per-epoch averages (generator/discriminator
        # losses and total batch size are updated below).
        train_gen_loss_accum, train_dis_loss_accum, train_likelihood_accum, train_kl_accum, batch_size_accum = 0, 0, 0, 0, 0
        start = time.time();
        for batch_idx, curr_batch_size, batch in data_loader: 

            # Discriminator step on every batch; generator only every 5th.
            disc_train_step_np = sess.run([train_discriminator_step_tf], feed_dict = input_dict_func(batch, np.asarray([0,])))
            if batch_idx % 5 !=0: continue
            gen_train_step_np, generator_cost_np, discriminator_cost_np = \
                sess.run([train_generator_step_tf, train_outs_dict['generator_cost'], train_outs_dict['discriminator_cost']],
                          feed_dict = input_dict_func(batch, np.asarray([0,])))

            # Track the largest absolute discriminator weight (useful to
            # monitor weight clipping / Lipschitz behavior).
            max_discriminator_weight = sess.run(max_abs_discriminator_vars)
            train_gen_loss_accum += curr_batch_size*generator_cost_np
            train_dis_loss_accum += curr_batch_size*discriminator_cost_np
            batch_size_accum += curr_batch_size

            if batch_idx % global_args.log_interval == 0:
                end = time.time();
                print('Train: Epoch {} [{:7d} ()]\tGenerator Cost: {:.6f}\tDiscriminator Cost: {:.6f}\tTime: {:.3f}, Max disc weight {:.6f}'.format(
                      epoch, batch_idx * curr_batch_size, generator_cost_np, discriminator_cost_np, (end - start), max_discriminator_weight))

                # Append raw costs to a plain-text trace file.
                with open(global_args.exp_dir+"training_traces.txt", "a") as text_file:
                    text_file.write(str(generator_cost_np) + ', ' + str(discriminator_cost_np) + '\n')
                start = time.time()
    
        # Summaries are computed on the last batch of the epoch.
        summary_str = sess.run(merged_summaries, feed_dict = input_dict_func(batch, np.asarray([0,])))
        summary_writer.add_summary(summary_str, (tf.train.global_step(sess, global_step)))
        
        # Toy/manifold datasets are cheap per epoch, so checkpoint less often.
        checkpoint_time = 1
        if data_loader.__module__ == 'datasetLoaders.RandomManifoldDataLoader' or data_loader.__module__ == 'datasetLoaders.ToyDataLoader':
            checkpoint_time = 20

        if epoch % checkpoint_time == 0:
            print('====> Average Train: Epoch: {}\tGenerator Cost: {:.6f}\tDiscriminator Cost: {:.6f}'.format(
                  epoch, train_gen_loss_accum/batch_size_accum, train_dis_loss_accum/batch_size_accum))

            # helper.draw_bar_plot(rate_similarity_gen_np[:,0,0], y_min_max = [0,1], save_dir=global_args.exp_dir+'Visualization/inversion_weight/', postfix='inversion_weight'+str(epoch))
            # helper.draw_bar_plot(effective_z_cost_np[:,0,0], thres = [np.mean(effective_z_cost_np), np.max(effective_z_cost_np)], save_dir=global_args.exp_dir+'Visualization/inversion_cost/', postfix='inversion_cost'+str(epoch))
            # helper.draw_bar_plot(disc_cost_gen_np[:,0,0], thres = [0, 0], save_dir=global_args.exp_dir+'Visualization/disc_cost/', postfix='disc_cost'+str(epoch))
            
            if data_loader.__module__ == 'datasetLoaders.RandomManifoldDataLoader' or data_loader.__module__ == 'datasetLoaders.ToyDataLoader':
                # 2-D toy data: visualize samples and the discriminator
                # function on a dense grid.
                helper.visualize_datasets(sess, input_dict_func(batch), data_loader.dataset, generative_dict['obs_sample_out'], generative_dict['latent_sample_out'],
                    save_dir=global_args.exp_dir+'Visualization/', postfix=str(epoch)) 
                
                xmin, xmax, ymin, ymax, X_dense, Y_dense = -3.5, 3.5, -3.5, 3.5, 250, 250
                xlist = np.linspace(xmin, xmax, X_dense)
                ylist = np.linspace(ymin, ymax, Y_dense)
                X, Y = np.meshgrid(xlist, ylist)
                XY = np.concatenate([X.reshape(-1,1), Y.reshape(-1,1)], axis=1)

                # Evaluate the critic on the grid points in place of real data.
                batch['observed']['data']['flat'] = XY[:, np.newaxis, :]
                disc_cost_real_np = sess.run(train_outs_dict['critic_real'], feed_dict = input_dict_func(batch, np.asarray([0,])))

                f = np.reshape(disc_cost_real_np[:,0,0], [Y_dense, X_dense])
                helper.plot_ffs(X, Y, f, save_dir=global_args.exp_dir+'Visualization/discriminator_function/', postfix='discriminator_function'+str(epoch))
            else:
                distributions.visualizeProductDistribution(sess, input_dict_func(batch), batch, inference_obs_dist, generative_dict['obs_dist'], 
                save_dir=global_args.exp_dir+'Visualization/Train/', postfix='train_'+str(epoch))

            # Primary checkpoint each checkpoint interval; a backup copy
            # in a second directory every 60 epochs.
            checkpoint_path1 = global_args.exp_dir+'checkpoint/'
            checkpoint_path2 = global_args.exp_dir+'checkpoint2/'
            print('====> Saving checkpoint. Epoch: ', epoch); start_tmp = time.time()
            helper.save_checkpoint(saver, sess, global_step, checkpoint_path1) 
            end_tmp = time.time(); print('Checkpoint path: '+checkpoint_path1+'   ====> It took: ', end_tmp - start_tmp)
            if epoch % 60 == 0: 
                print('====> Saving checkpoint backup. Epoch: ', epoch); start_tmp = time.time()
                helper.save_checkpoint(saver, sess, global_step, checkpoint_path2) 
                end_tmp = time.time(); print('Checkpoint path: '+checkpoint_path2+'   ====> It took: ', end_tmp - start_tmp)
                 dest="arch",
                 action="store",
                 default="resnet50",
                 type=str)
arg.add_argument('--hidden_unit',
                 type=int,
                 dest="hidden_units",
                 action="store",
                 default=768)
arg.add_argument('--compute', dest="compute", action="store", default="cuda")

# NOTE(review): the parsed namespace is confusingly named `parser` here
# (the ArgumentParser itself is `arg`).
parser = arg.parse_args()
# Presumably a --get_data option with dest="get_data" is defined earlier
# in this file — verify.
loc = parser.get_data
filepath = parser.save_model
lr = parser.learning_rate
architecture = parser.arch
dropout = parser.dropout
first_hidden_layer = parser.hidden_units
epochs = parser.epochs
compute = parser.compute

# Load data, build the model, train, and save a checkpoint.
trainloader, validloader, testloader, class_to_idx = helper.get_data(loc)
model, criterion, optimizer = helper.model_setup(architecture, dropout,
                                                 first_hidden_layer, lr,
                                                 compute)
helper.train_model(model, criterion, optimizer, epochs, compute, trainloader,
                   validloader)
helper.save_checkpoint(filepath, architecture, first_hidden_layer, lr, dropout,
                       model, class_to_idx)
print("Done")
Esempio n. 20
0
def main():
    """Fine-tune a pretrained torchvision model on the flowers dataset.

    Reads hyper-parameters from the command line (parse_args), builds the
    train/valid/test data pipeline, attaches a fresh classifier head to a
    frozen pretrained backbone, trains it, and saves a checkpoint.
    """
    print("hello")  # sanity output before the (slow) training starts

    args = parse_args()

    data_dir = 'flowers'
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/valid'
    test_dir = data_dir + '/test'

    # ImageNet statistics -- required because the backbone is pretrained.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # Augmentation is applied to the training split only.
    training_transforms = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation and test share the same deterministic preprocessing
    # (the original kept two identical copies, one with a typo in its name).
    eval_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    image_datasets = [
        ImageFolder(train_dir, transform=training_transforms),
        ImageFolder(val_dir, transform=eval_transforms),
        ImageFolder(test_dir, transform=eval_transforms)
    ]

    # Only the training loader needs shuffling; shuffling validation/test
    # buys nothing and makes runs harder to compare.
    dataloaders = [
        torch.utils.data.DataLoader(image_datasets[0],
                                    batch_size=64,
                                    shuffle=True),
        torch.utils.data.DataLoader(image_datasets[1],
                                    batch_size=64,
                                    shuffle=False),
        torch.utils.data.DataLoader(image_datasets[2],
                                    batch_size=64,
                                    shuffle=False)
    ]

    model = getattr(models, args.arch)(pretrained=True)

    # Freeze the pretrained feature extractor; only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False

    if args.arch == "vgg13":
        feature_num = model.classifier[0].in_features
        classifier = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(feature_num, 1024)),
                         ('drop', nn.Dropout(p=0.5)), ('relu', nn.ReLU()),
                         ('fc2', nn.Linear(1024, 102)),
                         ('output', nn.LogSoftmax(dim=1))]))
    elif args.arch == "densenet121":
        # BUG FIX: the OrderedDict keys must be unique.  The original used
        # 'drop' twice, so the second Dropout silently replaced the first
        # and the intended dropout after fc2 was never in the network.
        classifier = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(1024, 500)),
                         ('drop1', nn.Dropout(p=0.6)), ('relu1', nn.ReLU()),
                         ('fc2', nn.Linear(500, 256)),
                         ('drop2', nn.Dropout(p=0.6)), ('relu2', nn.ReLU()),
                         ('fc3', nn.Linear(256, 102)),
                         ('output', nn.LogSoftmax(dim=1))]))
    else:
        # Previously any other arch fell through and crashed later with a
        # NameError on `classifier`; fail with a clear message instead.
        raise ValueError(
            "Unsupported architecture '{}': expected 'vgg13' or "
            "'densenet121'".format(args.arch))

    model.classifier = classifier
    # NLLLoss pairs with the LogSoftmax output layer above.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.classifier.parameters(),
                           lr=float(args.learning_rate))
    epochs = int(args.epochs)
    class_index = image_datasets[0].class_to_idx
    gpu = args.gpu  # get the gpu settings
    train(model, criterion, optimizer, dataloaders, epochs, gpu)
    model.class_to_idx = class_index
    path = args.save_dir  # get the new save location
    save_checkpoint(path, model, optimizer, args, classifier)
Esempio n. 21
0
def train(args,
          params,
          train_configs,
          model,
          optimizer,
          current_lr,
          comet_tracker=None,
          resume=False,
          last_optim_step=0,
          reverse_cond=None):
    """Step-based optimization loop with periodic validation/sampling/checkpointing.

    Args:
        args: parsed CLI options (selects data loaders, paths, comet usage).
        params: hyper-parameter dict; keys read here: 'iter', 'lr',
            'monitor_val', 'val_freq', 'sample_freq', 'checkpoint_freq'.
        train_configs: dict; 'reg_factor' (None disables) weights 'loss_left'.
        model: the network being trained.
        optimizer: its optimizer; every epoch its param groups get the
            learning rate recomputed by adjust_lr().
        comet_tracker: metric tracker, used only when args.use_comet is set.
        resume: when True, continue step counting from last_optim_step.
        last_optim_step: step to resume from (meaningful with resume=True).
        reverse_cond: condition handed to models.take_samples() for sampling.

    Terminates by returning once optim_step exceeds params['iter'], or via
    sys.exit(0) when the decayed learning rate reaches exactly 0.
    """
    # getting data loaders
    train_loader, val_loader = data_handler.init_data_loaders(args, params)

    # adjusting optim step
    optim_step = last_optim_step + 1 if resume else 1
    max_optim_steps = params['iter']
    paths = helper.compute_paths(args, params)

    if resume:
        print(
            f'In [train]: resuming training from optim_step={optim_step} - max_step: {max_optim_steps}'
        )

    # optimization loop
    while optim_step < max_optim_steps:
        # after each epoch, adjust learning rate accordingly
        current_lr = adjust_lr(
            current_lr,
            initial_lr=params['lr'],
            step=optim_step,
            epoch_steps=len(
                train_loader))  # now only supports with batch size 1
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        print(
            f'In [train]: optimizer learning rate adjusted to: {current_lr}\n')

        for i_batch, batch in enumerate(train_loader):
            if optim_step > max_optim_steps:
                print(
                    f'In [train]: reaching max_step or lr is zero. Terminating...'
                )
                return  # ============ terminate training if max steps reached

            begin_time = time.time()
            # forward pass
            left_batch, right_batch, extra_cond_batch = data_handler.extract_batches(
                batch, args)
            forward_output = forward_and_loss(args, params, model, left_batch,
                                              right_batch, extra_cond_batch)

            # regularize left loss
            # total = reg_factor * loss_left + loss_right when regularized,
            # otherwise the combined 'loss' from forward_and_loss.
            if train_configs['reg_factor'] is not None:
                loss = train_configs['reg_factor'] * forward_output[
                    'loss_left'] + forward_output['loss_right']  # regularized
            else:
                loss = forward_output['loss']

            metrics = {'loss': loss}
            # also add left and right loss if available
            if 'loss_left' in forward_output.keys():
                metrics.update({
                    'loss_right': forward_output['loss_right'],
                    'loss_left': forward_output['loss_left']
                })

            # backward pass and optimizer step
            model.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'In [train]: Step: {optim_step} => loss: {loss.item():.3f}')

            # validation loss
            # (also forced on the final pass once the lr has decayed to 0)
            if (params['monitor_val'] and optim_step % params['val_freq']
                    == 0) or current_lr == 0:
                val_loss_mean, _ = calc_val_loss(args, params, model,
                                                 val_loader)
                metrics['val_loss'] = val_loss_mean
                print(
                    f'====== In [train]: val_loss mean: {round(val_loss_mean, 3)}'
                )

            # tracking metrics
            if args.use_comet:
                for key, value in metrics.items():  # track all metric values
                    comet_tracker.track_metric(key, round(value.item(), 3),
                                               optim_step)

            # saving samples
            if (optim_step % params['sample_freq'] == 0) or current_lr == 0:
                samples_path = paths['samples_path']
                helper.make_dir_if_not_exists(samples_path)
                sampled_images = models.take_samples(args, params, model,
                                                     reverse_cond)
                utils.save_image(
                    sampled_images,
                    f'{samples_path}/{str(optim_step).zfill(6)}.png',
                    nrow=10)
                print(
                    f'\nIn [train]: Sample saved at iteration {optim_step} to: \n"{samples_path}"\n'
                )

            # saving checkpoint
            if (optim_step > 0 and optim_step % params['checkpoint_freq']
                    == 0) or current_lr == 0:
                checkpoints_path = paths['checkpoints_path']
                helper.make_dir_if_not_exists(checkpoints_path)
                helper.save_checkpoint(checkpoints_path, optim_step, model,
                                       optimizer, loss, current_lr)
                print("In [train]: Checkpoint saved at iteration", optim_step,
                      '\n')

            optim_step += 1
            end_time = time.time()
            print(f'Iteration took: {round(end_time - begin_time, 2)}')
            helper.show_memory_usage()
            print('\n')

            if current_lr == 0:
                print(
                    'In [train]: current_lr = 0, terminating the training...')
                sys.exit(0)
Esempio n. 22
0
def main():
    """Build model/optimizer and data loaders, then train (or only evaluate).

    Relies on module-level globals: `args` (parsed CLI options), `device`,
    the model/dataset classes, and the `helper` module.  When --evaluate or
    --resume point at a checkpoint file, most options are restored from the
    checkpoint and selectively overridden by the freshly parsed ones.
    """
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        # Evaluate a finished model: restore its saved args, but keep the
        # data/eval options supplied on this invocation.
        args_new = args  # alias stays valid after `args` is rebound below
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.every = args_new.every
            args.evaluate = args_new.evaluate
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.every = args_new.every
            args.sparse_depth_source = args_new.sparse_depth_source
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')

    # model
    # Variant is selected by feature type / instancewise flags.
    # NOTE(review): if type_feature is neither "sq" nor "lines", `model`
    # is never assigned and the lines below raise NameError -- confirm the
    # argument parser restricts the choices.
    if args.type_feature == "sq":
        if args.instancewise:
            model = DepthCompletionNetQSquareNet(args).to(device)
        else:
            model = DepthCompletionNetQSquare(args).to(device)
    elif args.type_feature == "lines":
        model = DepthCompletionNetQ(args).to(device)
    # Optimize only parameters that still require gradients.
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        # strict=False tolerates key mismatches between checkpoint and model.
        model.load_state_dict(checkpoint['model'], strict=False)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")
    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=2,
    #     pin_memory=True)  # set batch size to be 1 for validation
    # print("\t==> val_loader size:{}".format(len(val_loader)))
    # Validation runs on the first 1000 samples only, batch size 1.
    val_dataset_sub = torch.utils.data.Subset(val_dataset, torch.arange(1000))
    val_loader = torch.utils.data.DataLoader(
        val_dataset_sub,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger, checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("\n\n=> starting training epoch {} .. \n\n".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger, epoch)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory, args.type_feature)
Esempio n. 23
0
            loss = criterion(out, y_valid)

            losses.update(loss.item(), x_valid.size(0))
            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print(
                f'Valid Epoch [{epoch + 1}/{NUM_EPOCH}] [{idx}/{len(valid_loader)}]\t'
                f' Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f' Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                f' Loss {losses.val:.4f} ({losses.avg:.4f}) ')

    history['valid_loss'].append(losses.avg)

    is_best = losses.avg < best_loss
    best_loss = min(losses.avg, best_loss)
    helper.save_checkpoint(
        {
            'epoch': epoch + 1,
            'batch_size': BSIZE,
            'learning_rate': LRATE,
            'total_clazz': len(CLAZZ),
            'class_to_idx': iris_dataset.class_to_idx,
            'labels': CLAZZ,
            'history': history,
            'arch': 'IrisNet',
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimiz1er': optimizer.state_dict(),
        }, is_best)
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one epoch of training or evaluation over `loader`.

    Args:
        mode: one of "train", "val", "eval", "test_prediction",
            "test_completion"; controls grad computation, LR adjustment,
            loss computation and checkpointing cadence.
        args: parsed options (feature type/mode, loss weights, train_mode...).
        loader: DataLoader yielding dict batches incl. 'd', 'gt', 'd_path'.
        model: the (DataParallel-wrapped) depth-completion network.
        optimizer: used only in "train" mode; None otherwise.
        logger: helper logger used for printing, images and checkpoints.
        epoch: current epoch index (for LR schedule and logging).

    Returns:
        (avg, is_best) from the last periodic logging step.
        NOTE(review): both are assigned only inside the
        `if i % every == 0 and i != 0` branch -- if the loader is too short
        for that branch to fire, the final return raises NameError. Confirm
        loader sizes guarantee at least one logging step.

    Relies on module globals: device, model_orig, depth_criterion,
    photometric_criterion, smoothness_criterion, kitti_intrinsics, helper.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0

    torch.set_printoptions(profile="full")
    table_is = np.zeros(400)  # NOTE(review): never read below -- dead?
    for i, batch_data in enumerate(loader):

        # 'd_path' is metadata (string path); drop it before moving tensors.
        sparse_depth_pathname = batch_data['d_path'][0]
        print(sparse_depth_pathname)
        del batch_data['d_path']
        print("i: ", i)
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None

        # adjust depth for features
        depth_adjust = args.depth_adjust
        adjust_features = False

        # Optionally subsample/rearrange the sparse depth input into
        # "sq" (squares) or "lines" feature patterns before the forward pass.
        if depth_adjust and args.use_d:
            if args.type_feature == "sq":
                if args.use_rgb:
                    depth_new, alg_mode, feat_mode, features, shape = depth_adjustment(
                        batch_data['d'], args.test_mode, args.feature_mode,
                        args.feature_num, args.rank_file_global_sq,
                        adjust_features, i, model_orig, args.seed,
                        batch_data['rgb'])
                else:
                    depth_new, alg_mode, feat_mode, features, shape = depth_adjustment(
                        batch_data['d'], args.test_mode, args.feature_mode,
                        args.feature_num, args.rank_file_global_sq,
                        adjust_features, i, model_orig, args.seed)
            elif args.type_feature == "lines":
                depth_new, alg_mode, feat_mode, features = depth_adjustment_lines(
                    batch_data['d'], args.test_mode, args.feature_mode,
                    args.feature_num, args.rank_file_global_sq, i, model_orig,
                    args.seed)

            batch_data['d'] = torch.Tensor(depth_new).unsqueeze(0).unsqueeze(
                1).to(device)
        data_time = time.time() - start
        start = time.time()
        # Forward pass; gradients only needed in training mode.
        if mode == "train":
            pred = model(batch_data)
        else:
            with torch.no_grad():
                pred = model(batch_data)
        # im = batch_data['d'].detach().cpu().numpy()
        # im_sq = im.squeeze()
        # plt.figure()
        # plt.imshow(im_sq)
        # plt.show()
        # for i in range(im_sq.shape[0]):
        #     print(f"{i} - {np.sum(im_sq[i])}")


#        pred = pred +0.155
#        gt = gt+0.155
# compute loss
        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 indicates that a pixel does not ground truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()
            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)
                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]
                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)
                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # coarser scales contribute with exponentially lower weight
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))
            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop
            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        gpu_time = time.time() - start

        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            [
                m.update(result, gpu_time, data_time, mini_batch_size)
                for m in meters
            ]
            print(f"rmse: {result.rmse:,}")
            if result.rmse < 6000:
                print("good rmse")
            elif result.rmse > 12000:
                print("bad rmse")
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        # save log and checkpoint
        # Periodic summaries: every 999 batches for validation, 200 otherwise.
        every = 999 if mode == "val" else 200
        if i % every == 0 and i != 0:

            print(
                f"test settings (main_orig eval): {args.type_feature} {args.test_mode} {args.feature_mode} {args.feature_num}"
            )
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            if mode != "val":
                #if 1:
                helper.save_checkpoint({  # save checkpoint
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'best_result': logger.best_result,
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                }, is_best, epoch, logger.output_directory, args.type_feature, args.test_mode, args.feature_num, args.feature_mode, args.depth_adjust, i, every, "scratch")

        # draw features
        # run_info = [args.type_feature, alg_mode, feat_mode, model_orig]
        # if batch_data['rgb'] != None and 1 and (i % 1) == 0:
        #     draw("sq", batch_data['rgb'], batch_data['d'], features, shape[1], run_info, i, result)

    return avg, is_best
def train_model(cust_model,
                dataloaders,
                criterion,
                optimizer,
                num_epochs,
                scheduler=None):
    """Train/validate `cust_model`, checkpointing the best (by Jaccard) state.

    Args:
        cust_model: network producing three stacked masks (label/inter/contour).
        dataloaders: dict with "train" and "valid" DataLoaders yielding
            (image, labels, inter, contours) tuples.
        criterion: loss over (predictions, stacked ground-truth masks).
        optimizer: optimizer for cust_model's parameters.
        num_epochs: absolute number of epochs (including any resumed ones).
        scheduler: optional LR scheduler, stepped once per epoch when given.

    Returns:
        (best model, list of per-epoch validation Jaccard scores).

    Uses module globals: `args` (for "lastepoch" resume), `load_checkpoint`,
    `save_checkpoint`, `jaccard`, `tqdm`.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start_time = time.time()
    val_acc_history = []
    best_acc = 0.0
    # BUG FIX: initialise the "best" stats up front so the per-epoch
    # save_checkpoint call cannot hit a NameError before the first
    # validation improvement.
    best_acc_inter = 0.0
    best_epoch_loss = float("inf")
    best_model_wts = copy.deepcopy(cust_model)
    best_optimizer_wts = optim.Adam(best_model_wts.parameters(), lr=0.0001)
    best_optimizer_wts.load_state_dict(optimizer.state_dict())
    start_epoch = args["lastepoch"] + 1
    if start_epoch > 1:
        # Resuming: reload model/optimizer state saved at the last epoch.
        filepath = "./checkpoint_epoch" + str(args["lastepoch"]) + ".pth"
        cust_model, optimizer = load_checkpoint(cust_model, filepath)
    for epoch in range(start_epoch - 1, num_epochs, 1):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        print("_" * 15)
        for phase in ["train", "valid"]:
            if phase == "train":
                cust_model.train()
            if phase == "valid":
                cust_model.eval()
            running_loss = 0.0
            jaccard_acc = 0.0
            jaccard_acc_inter = 0.0
            jaccard_acc_contour = 0.0

            for input_img, labels, inter, contours in tqdm(
                    dataloaders[phase], total=len(dataloaders[phase])):
                input_img = input_img.to(device)
                labels = labels.to(device)
                inter = inter.to(device)
                contours = contours.to(device)
                # Ground truth = three masks stacked on the channel dim.
                label_true = torch.cat([labels, inter, contours], 1)
                optimizer.zero_grad()

                # Gradients only during the training phase.
                with torch.set_grad_enabled(phase == "train"):
                    preds = cust_model(input_img)
                    loss = criterion(preds, label_true)
                    loss = loss.mean()

                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * input_img.size(0)
                preds = torch.cat(preds)  # gather per-GPU outputs (multi-GPU)

                # Jaccard computed on CPU against each ground-truth mask.
                # NOTE(review): all three comparisons use the full `preds`
                # tensor, not the matching channel -- confirm intended.
                jaccard_acc += jaccard(labels.to('cpu'),
                                       torch.sigmoid(preds.to('cpu')))
                jaccard_acc_inter += jaccard(inter.to('cpu'),
                                             torch.sigmoid(preds.to('cpu')))
                jaccard_acc_contour += jaccard(contours.to('cpu'),
                                               torch.sigmoid(preds.to('cpu')))

            epoch_loss = running_loss / len(dataloaders[phase])
            print("| {} Loss: {:.4f} |".format(phase, epoch_loss))
            aver_jaccard = jaccard_acc / len(dataloaders[phase])
            aver_jaccard_inter = jaccard_acc_inter / len(dataloaders[phase])
            aver_jaccard_contour = jaccard_acc_contour / len(
                dataloaders[phase])
            print(
                "| {} Loss: {:.4f} | Jaccard Average Acc: {:.4f} | Jaccard Average Acc inter: {:.4f}  | Jaccard Average Acc contour: {:.4f}| "
                .format(phase, epoch_loss, aver_jaccard, aver_jaccard_inter,
                        aver_jaccard_contour))
            print("_" * 15)
            if phase == "valid" and aver_jaccard > best_acc:
                # New best validation Jaccard: snapshot model + optimizer.
                best_acc = aver_jaccard
                best_acc_inter = aver_jaccard_inter
                best_epoch_loss = epoch_loss
                best_model_wts = copy.deepcopy(cust_model)
                best_optimizer_wts = optim.Adam(best_model_wts.parameters(),
                                                lr=0.0001)
                best_optimizer_wts.load_state_dict(optimizer.state_dict())
            if phase == "valid":
                val_acc_history.append(aver_jaccard)
        print("^" * 15)
        save_checkpoint(best_model_wts, best_optimizer_wts, epoch + 1,
                        best_epoch_loss, best_acc, best_acc_inter)
        print(" ")
        # BUG FIX: scheduler defaults to None but was stepped
        # unconditionally, crashing with AttributeError.
        if scheduler is not None:
            scheduler.step()
    time_elapsed = time.time() - start_time
    print("Training Complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    cust_model.load_state_dict(best_model_wts.state_dict())
    return cust_model, val_acc_history
Esempio n. 26
0
def main():
    """Entry point: build model/optimizer, optionally resume, train, evaluate.

    Reads configuration from the module-level `parser`; tracks and
    checkpoints the best top-1 precision via the global `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    # Name -> constructor table replaces the original 20-branch if/elif
    # chain; unknown names still raise NotImplementedError as before.
    arch_factories = {
        'alexnet': alexnet,
        'squeezenet1_0': squeezenet1_0,
        'squeezenet1_1': squeezenet1_1,
        'densenet121': densenet121,
        'densenet169': densenet169,
        'densenet201': densenet201,
        'densenet161': densenet161,
        'vgg11': vgg11,
        'vgg13': vgg13,
        'vgg16': vgg16,
        'vgg19': vgg19,
        'vgg11_bn': vgg11_bn,
        'vgg13_bn': vgg13_bn,
        'vgg16_bn': vgg16_bn,
        'vgg19_bn': vgg19_bn,
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    if args.arch not in arch_factories:
        raise NotImplementedError
    model = arch_factories[args.arch](pretrained=args.pretrained)

    # use cuda
    model.cuda()
    # model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # BUG FIX: this message was attached to `if args.resume`, so it
            # printed "no checkpoint found at 'None'" whenever --resume was
            # simply absent; it belongs to the missing-file case.
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = data_loader(args.data, args.batch_size,
                                           args.workers, args.pin_memory)

    if args.evaluate:
        validate(val_loader, model, criterion, args.print_freq)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              args.print_freq)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)

        # remember the best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.arch + '.pth')
Esempio n. 27
0
def main():
    """Test-set entry point for the depth-completion model.

    Uses the module-level ``args`` namespace and ``device``. Depending on
    the CLI flags it either:

    * ``--evaluate PATH``: restores a finished model and runs a single
      validation pass over the chosen split, or
    * ``--resume PATH``: restores a checkpoint and continues training, or
    * otherwise trains from ``args.start_epoch`` to ``args.epochs``.

    Requires ``--test yes``; refuses to run with ``--partial_train yes``.
    """
    global args
    if args.partial_train == 'yes':  # train on a part of the whole train set
        print(
            "Can't use partial train here. It is used only for test check. Exit..."
        )
        return

    if args.test != "yes":
        # This entry point is only meant for test-set evaluation.
        print(
            "This main should use only for testing, but test=yes was not given. Exit..."
        )
        return

    print("Evaluating test set with main_test:")
    whole_ts = time.time()
    checkpoint = None
    is_eval = False
    # Keep a handle on the CLI namespace up front: loading a checkpoint
    # rebinds ``args`` below, and ``args_new`` is also read later even when
    # neither --evaluate nor --resume is given (previously a NameError).
    # NOTE: this is a reference, not a copy.
    args_new = args
    if args.evaluate:  # test a finished model
        if os.path.isfile(args.evaluate):  # path is an existing regular file
            print("=> loading finished model from '{}' ... ".format(
                args.evaluate),
                  end='')  # "end=''" disables the newline
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            # CLI values that must win over the ones stored in the checkpoint.
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.save_images = args_new.save_images
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # resume from a checkpoint
        if os.path.isfile(args.resume):
            print("=> loading checkpoint from '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    # Hand only trainable parameters to the optimizer.
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    # DataParallel splits each batch across the available GPUs and merges
    # the per-GPU outputs back together after every forward pass.
    model = torch.nn.DataParallel(model)

    # data loading code
    print("=> creating data loaders ... ")
    if not is_eval:  # training also needs the train split
        train_dataset = KittiDepth('train',
                                   args)  # get the paths for the files
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)  # load them
        print("\t==> train_loader size:{}".format(len(train_loader)))

    if args_new.test == "yes":  # will take the data from the "test" folders
        val_dataset = KittiDepth('test', args)
        is_test = 'yes'
    else:
        val_dataset = KittiDepth('val', args)
        is_test = 'no'
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args, is_test)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")  # logger records sequential data to a log file

    # main code - run the NN
    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    print("=> starting model training ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> start training epoch {}".format(epoch) +
              "/{}..".format(args.epochs))
        train_ts = time.time()
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
        print("finish training epoch {}, time elapsed {:.2f} hours, \n".format(
            epoch, (time.time() - train_ts) / 3600))

    # Delete the last per-epoch checkpoint: the best model is saved
    # separately, so it is no longer needed.  Guarded so an empty epoch
    # range (start_epoch >= epochs) does not raise a NameError on ``epoch``,
    # and a missing file does not raise from os.remove.
    if args.start_epoch < args.epochs:
        last_checkpoint = os.path.join(
            logger.output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
        if os.path.isfile(last_checkpoint):
            os.remove(last_checkpoint)
    print("finished model training, time elapsed {0:.2f} hours, \n".format(
        (time.time() - whole_ts) / 3600))
Esempio n. 28
0
import argparse
import helper

# ---------------------------------------------------------------------------
# Command-line interface for training an image classifier.
# argparse derives each option's dest from its flag, and 'store' is the
# default action, so neither needs to be spelled out.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('data_dir', nargs='?', type=str, default='./flowers/')
parser.add_argument('--gpu', action='store_true', default=False)
parser.add_argument('--save_dir', default='./checkpoint.pth')
parser.add_argument('--arch', default='vgg16')
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--hidden_units', type=int, default=1024)
parser.add_argument('--epochs', type=int, default=20)

args = parser.parse_args()

# Load the training / test / validation data.
train_data, trainloader, testloader, validloader = helper.load_data()

# Echo a few settings for sanity checking, then build the network.
print(args.gpu)
print(args.arch)
print(type(args.hidden_units))
print(type(args.learning_rate))
model, device, criterion, optimizer = helper.build_model(
    args.gpu, args.arch, args.hidden_units, args.learning_rate)

# Train the model, then persist the result to disk.
helper.train_model(args.epochs, trainloader, validloader, model, device,
                   criterion, optimizer)
helper.save_checkpoint(model, args.epochs, args.arch, optimizer, train_data)

Esempio n. 29
0
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    """Run one full pass of ``loader`` through ``model`` in the given mode.

    Trains (with backprop) when ``mode == 'train'``; otherwise runs the
    model under ``torch.no_grad()``.  Along the way it updates the meters,
    periodically logs/saves checkpoints via ``logger``/``helper``, and (when
    switch ranking is enabled) dumps the softplus-normalised ``phi`` switch
    scores to ``ranks/...`` .npy files.

    Returns ``(avg, is_best)`` from the last periodic-save branch.
    NOTE(review): both names are assigned only inside the ``i % every == 0``
    branch, so an empty loader would leave them unbound — confirm loaders
    are never empty.  Relies on module-level globals (``device``,
    ``depth_criterion``, ``Ss``, ``old_i``, ...) defined elsewhere.
    """
    block_average_meter = AverageMeter()
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
    else:
        model.eval()
        lr = 0  # no optimizer steps outside training; lr is only for logging

    torch.set_printoptions(profile="full")
    for i, batch_data in enumerate(loader):

        name = batch_data['name'][0]
        print(name)
        # 'name' is metadata only — drop it so the remaining values can all
        # be moved to the compute device below.
        del batch_data['name']
        print("i: ", i)
        # each batch data is 1 and has three keys d, gt, g and dim [1, 352, 1216]
        start = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }
        # Ground truth is unavailable in the pure test-prediction modes.
        gt = batch_data[
            'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None

        # if args.type_feature=="sq":
        #     depth_adjustment(gt, False)

        data_time = time.time() - start

        start = time.time()
        if mode == "train":
            pred = model(batch_data)
        else:
            # inference only: skip autograd bookkeeping
            with torch.no_grad():
                pred = model(batch_data)

        vis=False  # debug flag: flip on to display the ground-truth map
        if vis:
            im = batch_data['gt'].detach().cpu().numpy()
            im_sq = im.squeeze()
            plt.figure()
            plt.imshow(im_sq)
            plt.show()
            # for i in range(im_sq.shape[0]):
            #     print(f"{i} - {np.sum(im_sq[i])}")

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
        if mode == 'train':
            # Loss 1: the direct depth supervision from ground truth label
            # mask=1 marks pixels that do not have ground-truth labels
            if 'sparse' in args.train_mode:
                depth_loss = depth_criterion(pred, batch_data['d'])
                print("d pts: ", len(torch.where(batch_data['d']>0)[0]))
                mask = (batch_data['d'] < 1e-3).float()
            elif 'dense' in args.train_mode:
                depth_loss = depth_criterion(pred, gt)
                mask = (gt < 1e-3).float()

            # Loss 2: the self-supervised photometric loss
            if args.use_pose:
                # create multi-scale pyramids
                pred_array = helper.multiscale(pred)
                rgb_curr_array = helper.multiscale(batch_data['rgb'])
                rgb_near_array = helper.multiscale(batch_data['rgb_near'])
                if mask is not None:
                    mask_array = helper.multiscale(mask)
                num_scales = len(pred_array)

                # compute photometric loss at multiple scales
                for scale in range(len(pred_array)):
                    pred_ = pred_array[scale]
                    rgb_curr_ = rgb_curr_array[scale]
                    rgb_near_ = rgb_near_array[scale]
                    mask_ = None
                    if mask is not None:
                        mask_ = mask_array[scale]

                    # compute the corresponding intrinsic parameters
                    height_, width_ = pred_.size(2), pred_.size(3)
                    intrinsics_ = kitti_intrinsics.scale(height_, width_)

                    # inverse warp from a nearby frame to the current frame
                    warped_ = homography_from(rgb_near_, pred_,
                                              batch_data['r_mat'],
                                              batch_data['t_vec'], intrinsics_)
                    # coarser scales get exponentially smaller weight
                    photometric_loss += photometric_criterion(
                        rgb_curr_, warped_, mask_) * (2**(scale - num_scales))

            # Loss 3: the depth smoothness loss
            smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0

            # backprop



            loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
            optimizer.zero_grad()
            loss.backward()
            # zero_params presumably masks out frozen gradients — TODO confirm
            zero_params(model)
            optimizer.step()



        gpu_time = time.time() - start

        # counting pixels in each bin
        #binned_pixels = np.load("value.npy", allow_pickle=True)
        #print(len(binned_pixels))

        # Dump switch ranks: every iteration in instance-wise evaluation,
        # every `args.every` iterations during global training.
        if (i % 1 == 0 and args.evaluate and args.instancewise) or\
                (i % args.every == 0 and not args.evaluate and not args.instancewise): # global training
            #    print(model.module.conv4[5].conv1.weight[0])
            # print(model.conv4.5.bn2.weight)
            # print(model.module.parameter.grad)
            #print("*************swiches:")
            torch.set_printoptions(precision=7, sci_mode=False)

            if model.module.phi is not None:
                # S: switch scores normalised to sum to 1 via softplus
                mmp = 1000 * model.module.phi
                phi = F.softplus(mmp)

                S = phi / torch.sum(phi)
                #print("S", S[1, -10:])
                S_numpy= S.detach().cpu().numpy()

            if args.instancewise:

                # Accumulate per-instance switch scores across iterations.
                global Ss
                if "Ss" not in globals():
                    Ss = []
                    Ss.append(S_numpy)
                else:
                    Ss.append(S_numpy)

            # GLOBAL
            if (i % args.every ==0  and not args.evaluate and not args.instancewise and model.module.phi is not None):

                np.set_printoptions(precision=4)

                switches_2d_argsort = np.argsort(S_numpy, None) # 2d to 1d sort torch.Size([9, 31])
                switches_2d_sort = np.sort(S_numpy, None)
                print("Switches: ")
                print(switches_2d_argsort[:10])
                print(switches_2d_sort[:10])
                print("and")
                print(switches_2d_argsort[-10:])
                print(switches_2d_sort[-10:])

                ##### saving global ranks
                global_ranks_path = lambda \
                    ii: f"ranks/{args.type_feature}/global/{folder_and_name[0]}/Ss_val_{folder_and_name[1]}_iter_{ii}.npy"
                # Keep only the newest rank file: delete the one saved at the
                # previous iteration (tracked in global old_i).
                global old_i
                if ("old_i" in globals()):
                    print("old_i")
                    if os.path.isfile(global_ranks_path(old_i)):
                        os.remove(global_ranks_path(old_i))

                folder_and_name = args.resume.split(os.sep)[-2:]
                os.makedirs(f"ranks/{args.type_feature}/global/{folder_and_name[0]}", exist_ok=True)
                np.save(global_ranks_path(i), S_numpy)
                old_i = i
                print("saving ranks")

                if args.type_feature == "sq":

                    # Convert flat sorted indices back to 2-D (row, col).
                    hor = switches_2d_argsort % S_numpy.shape[1]
                    ver = np.floor(switches_2d_argsort // S_numpy.shape[1])
                    print(ver[:10],hor[:10])
                    print("and")
                    print(ver[-10:], hor[-10:])


        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode != 'test_prediction' and mode != 'test_completion':
                result.evaluate(pred.data, gt.data, photometric_loss)
            [
                m.update(result, gpu_time, data_time, mini_batch_size)
                for m in meters
            ]
            logger.conditional_print(mode, i, epoch, lr, len(loader),
                                     block_average_meter, average_meter)
            logger.conditional_save_img_comparison(mode, i, batch_data, pred,
                                                   epoch)
            logger.conditional_save_pred(mode, i, pred, epoch)

        draw=False  # debug flag: flip on to render top-ranked features on the RGB frame
        if draw:
            # NOTE(review): this branch reads `hor`/`ver` from the rank dump
            # above; it only works when that branch ran first — confirm.
            ma = batch_data['rgb'].detach().cpu().numpy().squeeze()
            ma  = np.transpose(ma, axes=[1, 2, 0])
           # ma = np.uint8(ma)
            #ma2 = Image.fromarray(ma)
            ma2 = Image.fromarray(np.uint8(ma)).convert('RGB')
            # create rectangle image
            img1 = ImageDraw.Draw(ma2)

            if args.type_feature == "sq":
                size=40
                print_square_num = 20
                for ii in range(print_square_num):
                    s_hor=hor[-ii].detach().cpu().numpy()
                    s_ver=ver[-ii].detach().cpu().numpy()
                    shape = [(s_hor * size, s_ver * size), ((s_hor + 1) * size, (s_ver + 1) * size)]
                    img1.rectangle(shape, outline="red")

                    tim = time.time()
                    lala = ma2.save(f"switches_photos/squares/squares_{tim}.jpg")
                    print("saving")
            elif args.type_feature == "lines":
                print_square_num = 20
                r=1
                parameter_mask = np.load("../kitti_pixels_to_lines.npy", allow_pickle=True)

                # for m in range(10,50):
                #     im = Image.fromarray(parameter_mask[m]*155)
                #     im = im.convert('1')  # convert image to black and white
                #     im.save(f"switches_photos/lala_{m}.jpg")


                for ii in range(print_square_num):
                     points = parameter_mask[ii]
                     y = np.where(points==1)[0]
                     x = np.where(points == 1)[1]

                     for p in range(len(x)):
                         img1.ellipse((x[p] - r, y[p] - r, x[p] + r, y[p] + r), fill=(255, 0, 0, 0))

                tim = time.time()
                lala = ma2.save(f"switches_photos/lines/lines_{tim}.jpg")
                print("saving")

        every = args.every
        # Periodic summary / checkpoint (always fires at i == 0).
        if i % every ==0:

            print("saving")
            avg = logger.conditional_save_info(mode, average_meter, epoch)
            is_best = logger.rank_conditional_save_best(mode, avg, epoch)
            #is_best = True #saving all the checkpoints
            if is_best and not (mode == "train"):
                logger.save_img_comparison_as_best(mode, epoch)
            logger.conditional_summarize(mode, avg, is_best)

            if mode != "val":
                helper.save_checkpoint({  # save checkpoint
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'best_result': logger.best_result,
                    'optimizer': optimizer.state_dict(),
                    'args': args,
                }, is_best, epoch, logger.output_directory, args.type_feature, i, every, qnet=True)

    # After an instance-wise evaluation run, persist all collected switch
    # scores under a path derived from the evaluated checkpoint's location.
    if args.evaluate and args.instancewise:
        #filename = os.path.split(args.evaluate)[1]
        Ss_numpy = np.array(Ss)
        folder_and_name = args.evaluate.split(os.sep)[-3:]
        os.makedirs(f"ranks/instance/{folder_and_name[0]}", exist_ok=True)
        os.makedirs(f"ranks/instance/{folder_and_name[0]}/{folder_and_name[1]}", exist_ok=True)
        np.save(f"ranks/instance/{folder_and_name[0]}/{folder_and_name[1]}/Ss_val_{folder_and_name[2]}.npy", Ss)

    return avg, is_best
Esempio n. 30
0
    def train(epoch):
        """Run one training epoch of the TF1 session-based model.

        Closure variables from the enclosing scope: ``sess``,
        ``data_loader``, ``train_step_tf``, ``train_out_list``,
        ``input_dict_func``, ``merged_summaries``, ``summary_writer``,
        ``global_step``, ``global_args``, ``saver``, ``obs_dist``,
        ``sample_obs_dist``, ``obs_sample_out_tf``, ``latent_sample_out_tf``
        — all presumably defined in the surrounding function (TODO confirm).
        Logs traces every ``log_interval`` batches; every 10 epochs it
        visualizes distributions and saves a checkpoint (backup every 60).
        """
        data_loader.train()
        train_loss, curr_batch_size_accum = 0, 0
        start = time.time()
        for batch_idx, curr_batch_size, batch in data_loader:
            # One optimizer step; the session run also returns the current
            # loss components and annealing temperature for logging.
            train_step, batch_loss, train_elbo_per_sample, train_likelihood, train_kl, curr_temp =\
                sess.run([train_step_tf, *train_out_list], feed_dict = input_dict_func(batch))
            # Weight the running loss by batch size for a correct average.
            train_loss += curr_batch_size * batch_loss
            curr_batch_size_accum += curr_batch_size

            if batch_idx % global_args.log_interval == 0:
                end = time.time()
                print(
                    'Train Epoch: {} [{:7d} ()]\tLoss: {:.6f}\tLikelihood: {:.6f}\tKL: {:.6f}\tTime: {:.3f}, Temperature: {:.3f}'
                    .format(epoch, batch_idx * curr_batch_size, batch_loss,
                            train_likelihood, train_kl, (end - start),
                            curr_temp))
                # Append a CSV-ish trace line for offline plotting.
                with open(global_args.exp_dir + "training_traces.txt",
                          "a") as text_file:
                    trace_string = str(batch_loss) + ', ' + str(
                        train_likelihood) + ', ' + str(train_kl) + ', ' + str(
                            curr_temp) + '\n'
                    text_file.write(trace_string)
                start = time.time()

        # TensorBoard summaries are computed on the last batch of the epoch.
        summary_str = sess.run(merged_summaries,
                               feed_dict=input_dict_func(batch))
        summary_writer.add_summary(summary_str,
                                   (tf.train.global_step(sess, global_step)))

        if epoch % 10 == 0:
            print('====> Epoch: {} Average loss: {:.4f}'.format(
                epoch, train_loss / curr_batch_size_accum))

            distributions.visualizeProductDistribution(
                sess,
                input_dict_func(batch),
                batch,
                obs_dist,
                sample_obs_dist,
                save_dir=global_args.exp_dir + 'Visualization/',
                postfix='train')
            # Dataset-level visualisation only for the toy/manifold loaders.
            if data_loader.__module__ == 'datasetLoaders.RandomManifoldDataLoader' or data_loader.__module__ == 'datasetLoaders.ToyDataLoader':
                helper.visualize_datasets(sess,
                                          input_dict_func(batch),
                                          data_loader.dataset,
                                          obs_sample_out_tf,
                                          latent_sample_out_tf,
                                          save_dir=global_args.exp_dir +
                                          'Visualization/',
                                          postfix=str(epoch))

            checkpoint_path1 = global_args.exp_dir + 'checkpoint/'
            checkpoint_path2 = global_args.exp_dir + 'checkpoint2/'
            print('====> Saving checkpoint. Epoch: ', epoch)
            start_tmp = time.time()
            helper.save_checkpoint(saver, sess, global_step, checkpoint_path1)
            end_tmp = time.time()
            print(
                'Checkpoint path: ' + checkpoint_path1 + '   ====> It took: ',
                end_tmp - start_tmp)
            if epoch % 60 == 0:
                # Secondary backup copy, refreshed less frequently.
                print('====> Saving checkpoint backup. Epoch: ', epoch)
                start_tmp = time.time()
                helper.save_checkpoint(saver, sess, global_step,
                                       checkpoint_path2)
                end_tmp = time.time()
                print(
                    'Checkpoint path: ' + checkpoint_path2 +
                    '   ====> It took: ', end_tmp - start_tmp)