def train(data_dir, arch, hidden_units, output_size, dropout, lr, epochs, gpu,
          checkpoint):
    """Run the full training pipeline: load data, build the model, train it,
    and optionally save a checkpoint.

    Args:
        data_dir: directory containing the image dataset.
        arch: architecture name passed through to build_model().
        hidden_units: size of the classifier's hidden layer.
        output_size: number of output classes.
        dropout: dropout probability for the classifier.
        lr: learning rate.
        epochs: number of training epochs.
        gpu: whether to train on the GPU.
        checkpoint: checkpoint destination; falsy to skip saving.
    """
    # Typo fixes vs. original banner: "HiddenUints" -> "HiddenUnits".
    print(
        'Dir: {},\t Arch:{},\t HiddenUnits: {},\t lr: {},\t Epochs: {},\t gpu: {}\n'
        .format(data_dir, arch, hidden_units, lr, epochs, gpu))

    print('Loading Images from Directory...')
    trainloader, validloader, testloader, class_to_idx = get_loaders(data_dir)
    print('Images Loaded.\n')

    print('Building the Model...')
    model, criterion, optimizer = build_model(arch, hidden_units, output_size,
                                              dropout, lr)
    print('Model Built.\n')

    # Typo fix: "Beggining" -> "Beginning". 20 is the print-every interval
    # passed to training().
    print('Beginning the Training...')
    model, optimizer = training(model, trainloader, validloader, epochs, 20,
                                criterion, optimizer, gpu)
    print('Training Done.\n')

    if checkpoint:
        print('Saving the Checkpoint...')
        # The classifier's first layer supplies the checkpoint's input size.
        save_checkpoint(checkpoint, model, optimizer, arch,
                        model.classifier[0].in_features, output_size,
                        hidden_units, dropout, class_to_idx, epochs, lr)
        print('Done.')
    def save_checkpoint(self):
        '''
        Persist the current training state to disk.
        '''
        # The "latest" checkpoint is rewritten on every call.
        latest_path = osp.join(self.checkpoint_dir, 'checkpoint_latest.pth')

        # Bundle everything needed to resume, then delegate to
        # util.save_checkpoint, which also maintains the best-so-far file.
        state = {
            'epoch': self.epoch + 1,
            'experiment': self.name,
            'state_dict': self.network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'loss_meter': self.loss.to_dict(),
            'best_d1_inlier': self.best_d1_inlier
        }
        util.save_checkpoint(state, self.is_best, filename=latest_path)

        # Keep an immutable per-epoch copy of the latest checkpoint.
        history_path = osp.join(
            self.checkpoint_dir,
            'checkpoint_{:03d}.pth'.format(self.epoch + 1))
        shutil.copyfile(latest_path, history_path)
        print('Checkpoint saved')
Exemple #3
0
def main():
    """Train on the CATH dataset, validate each epoch, and test with the
    checkpoint from the epoch that achieved the best validation loss.

    sys.argv[1] must be the path to the jsonl dataset file.
    """
    trainset, valset, testset = cath_dataset(
        1800, jsonl_file=sys.argv[1])  # batch size = 1800 residues
    optimizer = tf.keras.optimizers.Adam()
    model = make_model()

    # Timestamp-based id keeps checkpoints from different runs distinct.
    model_id = int(datetime.timestamp(datetime.now()))

    NUM_EPOCHS = 100
    loop_func = util.loop
    best_epoch, best_val = 0, np.inf

    for epoch in range(NUM_EPOCHS):
        loss, acc, confusion = loop_func(trainset,
                                         model,
                                         train=True,
                                         optimizer=optimizer)
        util.save_checkpoint(model, optimizer, model_id, epoch)
        print('EPOCH {} TRAIN {:.4f} {:.4f}'.format(epoch, loss, acc))
        util.save_confusion(confusion)
        loss, acc, confusion = loop_func(valset, model, train=False)
        if loss < best_val:
            best_epoch, best_val = epoch, loss
        print('EPOCH {} VAL {:.4f} {:.4f}'.format(epoch, loss, acc))
        util.save_confusion(confusion)

    # Test with best validation loss.
    # Bug fix: the original loaded the checkpoint of the *last* epoch
    # (`epoch`); load `best_epoch` as the comment intends.
    path = util.models_dir.format(
        str(model_id).zfill(3), str(best_epoch).zfill(3))
    util.load_checkpoint(model, optimizer, path)
    loss, acc, confusion = loop_func(testset, model, train=False)
    print('EPOCH TEST {:.4f} {:.4f}'.format(loss, acc))
    util.save_confusion(confusion)
Exemple #4
0
def main():
    """Train on the Rocklin dataset, track the best validation loss, and
    evaluate the best-epoch checkpoint on the test set.

    Returns:
        The full metric tuple from the test pass:
        (loss, tp, fp, tn, fn, acc, prec, recall, auc, y_pred, y_true).
    """
    trainset, valset, testset = rocklin_dataset(32)  # batch size = 32
    optimizer = tf.keras.optimizers.Adam()
    model = make_model()

    # Timestamp-based id keeps checkpoints from different runs distinct.
    model_id = int(datetime.timestamp(datetime.now()))

    NUM_EPOCHS = 50
    loop_func = loop
    best_epoch, best_val = 0, np.inf

    for epoch in range(NUM_EPOCHS):
        loss = loop_func(trainset, model, train=True, optimizer=optimizer)
        print('EPOCH {} training loss: {}'.format(epoch, loss))
        save_checkpoint(model, optimizer, model_id, epoch)
        print('EPOCH {} TRAIN {:.4f}'.format(epoch, loss))
        loss = loop_func(valset, model, train=False, val=False)
        print(' EPOCH {} validation loss: {}'.format(epoch, loss))
        if loss < best_val:
            # Could select the "best" model by precision/auc/recall instead
            # of validation loss.
            best_epoch, best_val = epoch, loss
        print('EPOCH {} VAL {:.4f}'.format(epoch, loss))

    # Test with best validation loss.
    # Bug fix: the original loaded the checkpoint of the *last* epoch
    # (`epoch`); load `best_epoch` as the comment intends.
    path = models_dir.format(str(model_id).zfill(3), str(best_epoch).zfill(3))
    load_checkpoint(model, optimizer, path)
    loss, tp, fp, tn, fn, acc, prec, recall, auc, y_pred, y_true = loop_func(
        testset, model, train=False, val=True)
    print('EPOCH TEST {:.4f} {:.4f}'.format(loss, acc))
    return loss, tp, fp, tn, fn, acc, prec, recall, auc, y_pred, y_true
Exemple #5
0
 def __save_progress(self, total_epochs, model, loss):
     """Write a checkpoint for the current epoch under trained_models/<name>.

     Args:
         total_epochs: epochs completed so far (used in the file name).
         model: the model whose state is saved.
         loss: loss value stored alongside the model state.
     """
     # TimeMeasure times the save and reports it via logger.debug.
     with TimeMeasure(enter_msg="Saving progress...",
                      writer=logger.debug,
                      print_enabled=self.__print_enabled):
         # e.g. trained_models/<name>/epoch-00042.pt
         path = p_join("trained_models", self.__name,
                       "epoch-{:05d}.pt".format(total_epochs))
         save_checkpoint(path, total_epochs, model, loss,
                         self.__environment)
Exemple #6
0
def quantize_process(model):
    """Quantize `model` via weight sharing, retrain the quantized weights,
    and checkpoint/log after each stage; returns the (mutated) model.

    Relies on module-level `args`, `val_loader`, `train_loader` and the
    project's `util` helpers.
    NOTE(review): assumes the model exposes fc1..fc3 plus permutation index
    tensors invrow*/invcol*/rowp*/colp* — confirm against the model class.
    """
    print('------------------------------- accuracy before weight sharing ----------------------------------')
    acc = util.validate(val_loader, model, args)
    util.log(f"{args.save_dir}/{args.log}", f"accuracy before weight sharing\t{acc}")

    print('------------------------------- accuacy after weight sharing -------------------------------')

    # Apply the inverse row/column permutations to fc1..fc3 so weight
    # sharing operates on the un-permuted layout.
    tempfc1=torch.index_select(model.fc1.weight, 0, model.invrow1.cuda())
    model.fc1.weight=torch.nn.Parameter(torch.index_select(tempfc1, 1, model.invcol1.cuda()))
    tempfc2=torch.index_select(model.fc2.weight, 0, model.invrow2.cuda())
    model.fc2.weight=torch.nn.Parameter(torch.index_select(tempfc2, 1, model.invcol2.cuda()))
    tempfc3=torch.index_select(model.fc3.weight, 0, model.invrow3.cuda())
    model.fc3.weight=torch.nn.Parameter(torch.index_select(tempfc3, 1, model.invcol3.cuda()))

    # Cluster weights into shared values (presumably 2**args.bits centers —
    # confirm in apply_weight_sharing).
    old_weight_list, new_weight_list, quantized_index_list, quantized_center_list = apply_weight_sharing(model, args.model_mode, args.bits)

    # Re-apply the forward row/column permutations after sharing.
    temp1=torch.index_select(model.fc1.weight, 0, model.rowp1.cuda())
    model.fc1.weight=torch.nn.Parameter(torch.index_select(temp1, 1, model.colp1.cuda()))
    temp2=torch.index_select(model.fc2.weight, 0, model.rowp2.cuda())
    model.fc2.weight=torch.nn.Parameter(torch.index_select(temp2, 1, model.colp2.cuda()))
    temp3=torch.index_select(model.fc3.weight, 0, model.rowp3.cuda())
    model.fc3.weight=torch.nn.Parameter(torch.index_select(temp3, 1, model.colp3.cuda()))

    # Validate and checkpoint the quantized (pre-retrain) model.
    acc = util.validate(val_loader, model, args)
    util.save_checkpoint({
        'state_dict': model.state_dict(),
        'best_prec1': acc,
    }, True, filename=os.path.join(args.save_dir, 'checkpoint_{}_alpha_{}.tar'.format('quantized',args.alpha)))

    util.log(f"{args.save_dir}/{args.log}", f"weight\t{args.save_dir}/{args.out_quantized_folder}")
    util.log(f"{args.save_dir}/{args.log}", f"model\t{args.save_dir}/model_quantized.ptmodel")
    util.log(f"{args.save_dir}/{args.log}", f"accuracy after weight sharing {args.bits}bits\t{acc}")

    util.layer2torch(f"{args.save_dir}/{args.out_quantized_folder}" , model)
    util.save_parameters(f"{args.save_dir}/{args.out_quantized_folder}", new_weight_list)

    print('------------------------------- retraining -------------------------------------------')

    # Retrain while keeping weights tied to their shared cluster centers.
    util.quantized_retrain(model, args, quantized_index_list, quantized_center_list, train_loader, val_loader)

    # Validate and checkpoint again after retraining.
    acc = util.validate(val_loader, model, args)
    util.save_checkpoint({
        'state_dict': model.state_dict(),
        'best_prec1': acc,
    }, True, filename=os.path.join(args.save_dir, 'checkpoint_{}_alpha_{}.tar'.format('quantized_re',args.alpha)))

    util.layer2torch(f"{args.save_dir}/{args.out_quantized_re_folder}" , model)

    util.log(f"{args.save_dir}/{args.log}", f"weight:{args.save_dir}/{args.out_quantized_re_folder}")
    util.log(f"{args.save_dir}/{args.log}", f"model:{args.save_dir}/model_quantized_bit{args.bits}_retrain{args.reepochs}.ptmodel")
    util.log(f"{args.save_dir}/{args.log}", f"acc after qauntize and retrain\t{acc}")

    weight_list = util.parameters2list(model.children())
    util.save_parameters(f"{args.save_dir}/{args.out_quantized_re_folder}", weight_list)
    return model
Exemple #7
0
def pruning_process(model):
    """Prune the model's conv layers by per-layer percentiles, checkpoint,
    retrain, and log sparsity/accuracy at each stage; returns the model.

    Relies on module-level `args`, `val_loader`, `train_loader` and the
    project's `util` helpers.
    """
    log_path = f"{args.save_dir}/{args.log}"

    print("------------------------- Before pruning --------------------------------")
    util.print_nonzeros(model, log_path)
    accuracy = util.validate(val_loader, model, args)

    print("------------------------- pruning CNN--------------------------------------")
    # Per-layer keep percentages; prune_by_percentile drops the weakest
    # (100 - keep)% of weights in each listed layer.
    keep_pct = {
        'conv1': 58.0, 'conv2': 22.0, 'conv3': 34.0, 'conv4': 36.0,
        'conv5': 53.0, 'conv6': 24.0, 'conv7': 42.0, 'conv8': 32.0,
        'conv9': 27.0, 'conv10': 34.0, 'conv11': 35.0, 'conv12': 29.0,
        'conv13': 36.0,
    }
    for layer_name, keep in keep_pct.items():
        model.prune_by_percentile([layer_name], q=100 - keep)

    print("------------------------------- After prune CNN ----------------------------")
    util.print_nonzeros(model, log_path)

    prec1 = util.validate(val_loader, model, args)

    state = {
        'state_dict': model.state_dict(),
        'best_prec1': prec1,
    }
    ckpt_name = 'checkpoint_{}_alpha_{}.tar'.format('pruned',args.alpha)
    util.save_checkpoint(state, True,
                         filename=os.path.join(args.save_dir, ckpt_name))

    util.log(log_path, f"weight\t{args.save_dir}/{args.out_pruned_folder}")
    util.log(log_path, f"model\t{args.save_dir}/model_pruned.ptmodel")
    util.log(log_path, f"prune acc\t{prec1}")

    util.layer2torch(f"{args.save_dir}/{args.out_pruned_folder}" , model)
    util.save_parameters(f"{args.save_dir}/{args.out_pruned_folder}",
                         util.parameters2list(model.children()))

    print("------------------------- start retrain after prune CNN----------------------------")
    util.initial_train(model, args, train_loader, val_loader, 'prune_re')

    print("------------------------- After Retraining -----------------------------")
    util.print_nonzeros(model, log_path)
    accuracy = util.validate(val_loader, model, args)

    util.log(log_path, f"weight\t{args.save_dir}/{args.out_pruned_re_folder}")
    util.log(log_path, f"model\t{args.save_dir}/model_prune_retrain_{args.reepochs}.ptmodel")
    util.log(log_path, f"prune and retrain acc\t{accuracy}")

    util.layer2torch(f"{args.save_dir}/{args.out_pruned_re_folder}" , model)
    weight_list = util.parameters2list(model.children())
    util.save_parameters(f"{args.save_dir}/{args.out_pruned_re_folder}", weight_list)

    return model
Exemple #8
0
def main():
    """Entry point: build model and data, then either run inference, a single
    validation pass, or the full training loop with best-mAP checkpointing
    and TensorBoard logging (depending on args flags)."""
    pprint(vars(args))

    global best_mAP
    model, optimizer = build_model()
    train_loader, val_loader, test_loader = load_data(model.module)
    criterion = nn.CrossEntropyLoss().cuda()

    # Inference-only mode: no training or validation.
    if args.inference:
        inference(test_loader, model)
        return

    # Evaluation-only mode.
    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        mAP, prec = validate(val_loader, model, criterion)

        # remember best mAP and save checkpoint
        is_best = mAP > best_mAP
        best_mAP = max(mAP, best_mAP)

        # save checkpoint
        state = {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_mAP': best_mAP,
            'optimizer': optimizer.state_dict(),
        }
        best_path = os.path.join(args.model_save_path, 'best_models')
        # Fix: exist_ok avoids the check-then-create race of the original
        # os.path.exists() + os.makedirs() pair.
        os.makedirs(best_path, exist_ok=True)
        class_name = train_loader.dataset.class_name
        filename = "{}/{}_{}_{}_checkpoint.pth.tar".format(
            args.model_save_path, args.arch, class_name, mAP)
        best_filename = '{}/{}_{}.pth.tar'.format(best_path, args.arch,
                                                  class_name)
        save_checkpoint(state, is_best, filename, best_filename)

        # tensorboard record
        writer.add_scalar('val_prec', prec, global_train_step)
        writer.add_scalar('val_mAP', mAP, global_train_step)
        writer.add_scalar('val_prec_epoch', prec, epoch)
        writer.add_scalar('val_mAP_epoch', mAP, epoch)
        writer.file_writer.flush()

        # print mAP
        print(' * best mAP = {best_mAP:.3f}'.format(best_mAP=best_mAP))
Exemple #9
0
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        """Per-iteration training callback: log losses, checkpoint, and
        periodically evaluate the models plus gradient-noise statistics.

        Args:
            iteration: current iteration; drives all interval checks.
            wake_theta_loss, wake_phi_loss, elbo: scalars from the last step.
            generative_model, inference_network: the models being trained.
            optimizer_theta, optimizer_phi: unused here; part of the
                callback signature.
        """
        # Record losses every logging_interval iterations.
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        # Persist this stats object and a model checkpoint.
        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        # Evaluation: held-out log p / KL plus a gradient-noise estimate.
        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

            # 10 repeated wake-theta/wake-phi backward passes on the same
            # test observations feed an online mean/std accumulator.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                wake_theta_loss, elbo = losses.get_wake_theta_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_theta_loss.backward()
                # clone() so later zero_grad()/backward() calls cannot
                # mutate these gradient tensors in place.
                theta_grads = [
                    p.grad.clone() for p in generative_model.parameters()
                ]

                inference_network.zero_grad()
                wake_phi_loss = losses.get_wake_phi_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_phi_loss.backward()
                phi_grads = [p.grad for p in inference_network.parameters()]

                stats.update(theta_grads + phi_grads)
            # Keep only the std component of the aggregated statistics.
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
Exemple #10
0
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        """Per-iteration callback: log loss/elbo, checkpoint, and
        periodically evaluate log p, KL, the alpha-Renyi diagnostic, and the
        gradient-noise level of the thermo-alpha loss.

        Args:
            iteration: current iteration; drives all interval checks.
            loss, elbo: scalars from the last training step.
            generative_model, inference_network: the models being trained.
            optimizer: unused here; part of the callback signature.
        """
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        # Persist this stats object and a model checkpoint.
        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            # Held-out log-likelihood and KL ...
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            # ... plus the alpha-dependent Renyi diagnostic.
            _, renyi = eval_gen_inf_alpha(generative_model, inference_network,
                                          self.test_data_loader,
                                          self.eval_num_particles, self.alpha)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)
            self.renyi_history.append(renyi)

            # Gradient-noise estimate: 10 repeated backward passes of the
            # thermo-alpha loss on the same test observations.
            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                inference_network.zero_grad()
                loss, elbo = losses.get_thermo_alpha_loss(
                    generative_model, inference_network, self.test_obs,
                    self.partition, self.num_particles, self.alpha,
                    self.integration)
                loss.backward()
                stats.update([p.grad for p in generative_model.parameters()] +
                             [p.grad for p in inference_network.parameters()])
            # Keep only the std component of the aggregated statistics.
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}, renyi = {:.3f}'.
                format(iteration, self.log_p_history[-1], self.kl_history[-1],
                       self.renyi_history[-1]))
Exemple #11
0
def train_epochs(resume=False, use_glove=True):
    """Run cfg.EPOCHS training epochs, checkpointing whenever the test
    accuracy improves, then pickle the loss/error-rate histories.

    Args:
        resume: resume from a saved checkpoint via get_model().
        use_glove: use GloVe embeddings when building the model.
    """
    print('total epochs: ', cfg.EPOCHS, '; use_glove: ', use_glove)

    training_data, word_to_idx, label_to_idx = data_loader()
    model, best_acc, start_epoch = get_model(word_to_idx, label_to_idx, resume,
                                             use_glove)

    loss_function = nn.NLLLoss()
    # CNN runs use Adam; everything else uses SGD with momentum.
    if cfg.RUN_MODE == 'CNN':
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    else:
        optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.1)

    since = time.time()
    losses = []
    training_error_rates = []
    test_error_rates = []
    for epoch in range(1 + start_epoch, start_epoch + cfg.EPOCHS + 1):
        train_error, train_loss = train(model, loss_function, optimizer,
                                        training_data, word_to_idx)
        losses.append(train_loss)
        training_error_rates.append(train_error)
        test_error_rate = get_error_rate(model, training=False)
        test_error_rates.append(test_error_rate)
        acc = 1 - test_error_rate
        print('epoch: {}, time: {:.2f}s, cost so far: {}, accurary: {:.3f}'.
              format(epoch, (time.time() - since), train_loss.numpy(), acc))
        # Checkpoint only on improvement.
        if acc > best_acc:
            save_checkpoint(model, acc, epoch)
            best_acc = acc

    # Persist the training history for later analysis/plotting.
    save_to_pickle('checkpoint/all_losses.p', losses)
    save_to_pickle('checkpoint/training_error_rates.p', training_error_rates)
    save_to_pickle('checkpoint/test_error_rates.p', test_error_rates)
Exemple #12
0
def main():
    """Load data, build the network, train it inside an active session, and
    save a checkpoint when training finishes."""
    loaders = util.transform_load_image(pa.data_dir)
    trainloaders, validloaders, testloaders, train_data = loaders

    model, criterion, optimizer = util.construct_newwork(
        pa.model_type, pa.drop_out, pa.hidden_layer_1, pa.hidden_layer_2,
        pa.output_units, pa.learning_rate)

    # active_session keeps the workspace alive during long training runs.
    with active_session():
        util.test_network(model, criterion, optimizer, trainloaders,
                          validloaders, pa.epochs, print_every=40, steps=0)
        util.save_checkpoint(model, train_data, optimizer, pa.save_dir,
                             pa.epochs)

    print('Training completed')
Exemple #13
0
def eval_save_model(i,
                    pred_l=None,
                    target_l=None,
                    datapath=None,
                    save=False,
                    output=True):
    """Evaluate (from predictions, a data path, or the training set), log
    the f1 metrics to TensorBoard, and optionally save a checkpoint.

    Args:
        i: epoch index, used in the checkpoint file name.
        pred_l / target_l: precomputed predictions and targets to score.
        datapath: alternative data file to evaluate on.
        save: write a checkpoint after evaluating.
        output: forwarded to evaluate() when scoring predictions.
    """
    # Pick the evaluation function that matches the run mode.
    if args.mode == 'step-sl':
        test_f = test_step_sl
    elif args.mode == 'rl-taxo':
        test_f = test_taxo
    else:
        test_f = test_sl

    # Three evaluation sources, in priority order.
    if pred_l:
        metrics = evaluate(pred_l, target_l, output=output)
    elif datapath:
        metrics = test_f(datapath)
    else:
        metrics = test_f(X_train, train_ids)
    f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = metrics

    writer.add_scalar('data/micro_train', f1_aa, tree.n_update)
    writer.add_scalar('data/macro_train', f1_aa_macro, tree.n_update)
    writer.add_scalar('data/samples_train', f1_aa_s, tree.n_update)
    if not save:
        return

    # RL/step modes checkpoint the policy network; everything else the model.
    net = policy if args.mode in ['rl-taxo', 'step-sl'] else model
    save_checkpoint(
        {
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, writer.file_writer.get_logdir(),
        f'epoch{i}_{f1_aa}_{f1_aa_macro}_{f1_aa_s}.pth.tar', logger, True)
Exemple #14
0
    def test_epoch(self, epoch):
        running_loss = 0.0
        running_corrects = 0
        self.model.eval()

        # Iterate over data.
        for inputs, labels in self.dataloaders["test"]:
            inputs = inputs.permute(1, 0, 2, 3, 4)
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            # zero the parameter gradients
            self.optimizer.zero_grad()
            # forward
            # track history if only in train
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)
                # print('-- outputs size: ', outputs.size())
                # print('-- labels size: ',labels.size())
                loss = self.criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(self.dataloaders["test"].dataset)
        epoch_acc = running_corrects.double() / len(
            self.dataloaders["test"].dataset)

        print("{} Loss: {:.4f} Acc: {:.4f}".format("test", epoch_loss,
                                                   epoch_acc))
        if self.checkpoint_path != None and epoch_acc > self.best_acc:
            self.best_acc = epoch_acc
            self.best_model_wts = copy.deepcopy(self.model.state_dict())
            save_checkpoint(self.model, self.checkpoint_path)
        # self.tb.save_value("testLoss", "test_loss", epoch, epoch_loss)
        # self.tb.save_value("testAcc", "test_acc", epoch, epoch_acc)

        return epoch_loss, epoch_acc
Exemple #15
0
def train(model, args):
    """Train `model` for up to args.n_epoch epochs with SGD and
    reduce-on-plateau LR decay, checkpointing whenever evaluation accuracy
    improves.

    Returns:
        Path of the last checkpoint written, or None if accuracy never
        improved over 0.
    """
    since = time.time()
    dataloader, data_set_size = load_dataset(args, mode='train')
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # Halve the LR after 4 epochs without improvement in the epoch loss.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=4,
                                               verbose=True,
                                               factor=0.5)

    checkpoint = None
    max_accuracy = 0.0
    early_stop_count = 40  # budget of non-improving epochs before stopping
    for epoch in range(args.n_epoch):
        model.train()
        loss_sum = 0.0
        correct_sum = 0

        for step, (inputs, labels) in tqdm(enumerate(dataloader)):
            inputs = inputs.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs, loss = model(inputs, labels)
            loss = loss.sum()
            loss.backward()
            optimizer.step()
            _, predictions = torch.max(outputs, 1)
            loss_sum += loss.item() * inputs.size(0)
            correct_sum += torch.sum(torch.eq(predictions, labels).double())

        epoch_loss = loss_sum / data_set_size
        scheduler.step(epoch_loss)
        epoch_acc = float(correct_sum) / data_set_size

        print('Epoc: {} Loss: {:.4f} Acc: {:.4f}, lr:{}'.format(
            epoch, epoch_loss, epoch_acc, optimizer.param_groups[0]['lr']))
        # The epoch > 10 flag presumably switches evaluate() to a fuller
        # evaluation path after warm-up — confirm with evaluate().
        accuracy = evaluate(model, args, epoch, epoch > 10)
        if max_accuracy < accuracy:
            checkpoint = save_checkpoint(model, args.output_dir, epoch)
            max_accuracy = accuracy
        else:
            early_stop_count -= 1

        if early_stop_count < 0:
            break

    elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        elapsed // 60, elapsed % 60))
    return checkpoint
def train(model, optimizer, generate_data, loss_function, train_func):
    """Generic training driver: generate task-specific data, run
    config.NUM_EPOCHS epochs of batched training, log the per-epoch loss,
    and checkpoint every 25 epochs.

    Args:
        model, optimizer: objects forwarded to train_func/save_checkpoint.
        generate_data: callable(size, seq_length) -> (X, Y).
        loss_function: loss forwarded to train_func.
        train_func: callable running one batch and returning its loss.
    """
    # Data generation is dependent on the training task.
    train_X, train_Y = generate_data(config.TRAIN_SIZE, config.SEQ_LENGTH)
    training_results = []

    logging.info('----------------- Started Training -----------------')
    for epoch in range(1, config.NUM_EPOCHS + 1):
        started = time.time()
        loss_sum = 0
        n_batches = 0

        for batch_x, batch_y in get_batched_data(train_X, train_Y):
            loss_sum += train_func(model, optimizer, loss_function, batch_x,
                                   batch_y)
            n_batches += 1

        logging.info('Epoch {} \t => Loss: {} [Batch-Time = {}s]'.format(
            epoch, loss_sum / n_batches, round(time.time() - started, 2)))

        # Periodic checkpoint every 25 epochs.
        if epoch % 25 == 0:
            logging.info("Saved model")
            save_checkpoint(model, optimizer, epoch)
Exemple #17
0
def train(model, optimizer, stats, run_args=None):
    """Optimize `model` on random inputs for run_args.num_iterations steps,
    resuming from len(stats.losses), with periodic logging and two kinds of
    checkpoints (a rolling "latest" and immutable per-iteration snapshots).
    """
    device = model.device
    checkpoint_path = util.get_checkpoint_path(run_args)
    start = len(stats.losses)  # resume where a previous run left off

    for iteration in range(start, run_args.num_iterations):
        # Forward pass on fresh random inputs.
        loss = model(torch.rand(3, device=device),
                     torch.rand((), device=device))

        # Backward and parameter update.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        stats.losses.append(loss.item())

        if iteration % run_args.log_interval == 0:
            util.logging.info("it. {}/{} | loss = {:.2f}".format(
                iteration, run_args.num_iterations, loss))

        # Rolling checkpoint, overwritten each save_interval.
        if iteration % run_args.save_interval == 0:
            util.save_checkpoint(checkpoint_path,
                                 model,
                                 optimizer,
                                 stats,
                                 run_args=run_args)

        # Immutable per-iteration snapshot.
        if iteration % run_args.checkpoint_interval == 0:
            util.save_checkpoint(
                util.get_checkpoint_path(run_args,
                                         checkpoint_iteration=iteration),
                model,
                optimizer,
                stats,
                run_args=run_args,
            )

    # Final checkpoint once training completes.
    util.save_checkpoint(checkpoint_path,
                         model,
                         optimizer,
                         stats,
                         run_args=run_args)
def main():
    """End-to-end training entry point: parse args, build loaders, model and
    optimizer, then train for args.epochs epochs while tracking the best
    end-point error (EPE) and checkpointing each epoch."""
    global args, best_EPE
    args = parser.parse_args()

    # Encode the key hyper-parameters in the output directory name.
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)

    print('=> will save everything to {}'.format(save_path))
    # Fix: exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(save_path, exist_ok=True)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))

    # Data loading code
    transform = transforms.Compose([Normalization()])

    train_set = SpecklesDataset(csv_file='~/Train_annotations.csv',
                                root_dir='~/Train_Data/',
                                transform=transform)
    test_set = SpecklesDataset(csv_file='~/Test_annotations.csv',
                               root_dir='~/Test_Data/',
                               transform=transform)

    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=True)

    # create model (optionally seeded with pre-trained weights)
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        print('=> using pre-trained model')
    else:
        network_data = None
        print('creating model')

    model = models.__dict__[args.arch](network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    # Fix: validate explicitly instead of `assert` — asserts are stripped
    # under -O, which would leave `optimizer` unbound for unknown solvers.
    if args.solver not in ('adam', 'sgd'):
        raise ValueError('unsupported solver: {}'.format(args.solver))
    print('=> setting {} solver'.format(args.solver))
    # Separate weight-decay settings for biases vs. weights.
    param_groups = [{
        'params': model.module.bias_parameters(),
        'weight_decay': args.bias_decay
    }, {
        'params': model.module.weight_parameters(),
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    else:
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    # Halve the LR at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.5)

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      train_writer, scheduler)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)

        # evaluate on test dataset
        with torch.no_grad():
            EPE = validate(val_loader, model, epoch)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        # A negative best_EPE is the "not initialized yet" sentinel.
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
                'div_flow': args.div_flow
            }, is_best, save_path)
Exemple #19
0
def train_model(model,
                criterion,
                optimizer,
                scheduler,
                dataloaders,
                dataset_sizes,
                ckpt_save_path,
                use_gpu=True,
                num_epochs=25):
    """Train `model` and return it loaded with the best-validation weights.

    Args:
        model: network to optimize (updated in place).
        criterion: loss function mapping (outputs, labels) -> scalar loss.
        optimizer: optimizer over `model`'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        dataloaders: dict with 'train' and 'val' iterables yielding
            (inputs, labels) batches.
        dataset_sizes: dict with the number of samples per phase, used to
            average the running loss/accuracy.
        ckpt_save_path: directory for periodic checkpoints (created on demand).
        use_gpu: move each batch to CUDA when True.
        num_epochs: number of epochs to run.

    Returns:
        The model with the weights of the best validation-accuracy epoch.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    step = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                step += 1
                # Move the batch to GPU if requested; wrapping in `Variable`
                # is obsolete (deprecated since PyTorch 0.4).
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics -- `loss.data[0]` raises on PyTorch >= 0.5; use
                # `.item()`.  Accumulate plain Python numbers so the epoch
                # averages below are true float division.
                batch_loss = loss.item() * inputs.size(0)
                running_loss += batch_loss
                running_corrects += torch.sum(preds == labels).item()
                if step % 10 == 0:
                    print('  Step {} Loss: {:.4f}'.format(step, batch_loss))

            if phase == 'train':
                # Step the scheduler only after the epoch's optimizer updates
                # (required ordering since PyTorch 1.1).
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            # save model every 5 epochs
            if epoch % 5 == 0:
                os.makedirs(ckpt_save_path, exist_ok=True)
                checkpoint_name = os.path.join(ckpt_save_path,
                                               'ckpt-{}.pth.tar'.format(epoch))
                save_checkpoint(model, best_acc == epoch_acc, checkpoint_name)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
        is_best = False
        if not_improved >= args.patience:
            logging.info(
                f"Performance did not improve for {not_improved} epochs. Stop training."
            )
            break
        not_improved += 1
        logging.info(
            f"Not improved: {not_improved} / {args.patience}: best R@5 = {best_score:.1f}, current R@5 = {recalls[1]:.1f}"
        )

    util.save_checkpoint(args, {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "recalls": recalls,
        "best_score": best_score,
        "optimizer": optimizer.state_dict(),
    },
                         is_best,
                         filename=f"model_{epoch:02d}.pth")

    logging.info(f"Start training epoch: {epoch:02d}")

    train.train(args, epoch, model, optimizer, criterion_netvlad,
                whole_train_set, query_train_set, grl_dataset)

logging.info(f"Best R@5: {best_score:.1f}")
logging.info(
    f"Trained for {epoch:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}"
)
Exemple #21
0
            break

    print("Epoch [%d] Loss: %.4f" % (epoch + 1, epoch_loss))
    log_value('loss', epoch_loss, epoch)
    log_value('lr', args.lr, epoch)

    if args.adjust_lr:
        args.lr = adjust_learning_rate(optimizer, args.lr, args.weight_decay,
                                       epoch, args.epochs)

    if args.net == "fcn" or args.net == "psp":
        checkpoint_fn = os.path.join(
            args.pth_dir, "%s-%s-res%s-%s.pth.tar" %
            (args.savename, args.net, args.res, epoch + 1))
    else:
        checkpoint_fn = os.path.join(
            args.pth_dir,
            "%s-%s-%s.pth.tar" % (args.savename, args.net, epoch + 1))

    args.start_epoch = epoch + 1
    save_dic = {
        'args': args,
        'epoch': epoch + 1,
        'g1_state_dict': model_g1.state_dict(),
        'g2_state_dict': model_g2.state_dict(),
        'f1_state_dict': model_f1.state_dict(),
        'optimizer': optimizer.state_dict()
    }

    save_checkpoint(save_dic, is_best=False, filename=checkpoint_fn)
def main():
    """Train a student network against a frozen teacher (FitNet-style).

    All configuration comes from the module-level argparse ``parser``: loads
    the pre-trained teacher (frozen) and the initial student, builds CIFAR
    data loaders, then alternates train/test for ``args.epochs`` epochs,
    saving a checkpoint after every epoch.

    Raises:
        ValueError: if ``args.data_name`` is not 'cifar10' or 'cifar100'.
    """
    global args
    args = parser.parse_args()
    print(args)

    # All checkpoints live under <save_root>/checkpoint; create it up front.
    ckpt_dir = os.path.join(args.save_root, 'checkpoint')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    if args.cuda:
        cudnn.benchmark = True

    print('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name, num_class=args.num_class, cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])

    # The teacher is fixed: eval mode and no gradients anywhere.
    tnet = define_tsnet(name=args.t_name, num_class=args.num_class, cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['net'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    print('-----------------------------------------------')

    # initialize optimizer (Nesterov SGD over the student's parameters only)
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define loss functions: CE for classification, MSE for the hint loss
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
        criterionFitnet = torch.nn.MSELoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()
        criterionFitnet = torch.nn.MSELoss()

    # dataset selection with the matching per-channel normalization stats
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        # ValueError (still an Exception subclass, so existing handlers keep
        # working) is the conventional error for a bad configuration value.
        raise ValueError('invalid dataset name...')

    # define transforms: padded random crop + flip for train, center crop
    # for test
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loaders
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()

        adjust_lr(optimizer, epoch)

        # train one epoch
        nets = {'snet': snet, 'tnet': tnet}
        criterions = {'criterionCls': criterionCls, 'criterionFitnet': criterionFitnet}
        train(train_loader, nets, optimizer, criterions, epoch)
        epoch_time = time.time() - epoch_start_time
        print('one epoch time is {:02}h{:02}m{:02}s'.format(*transform_time(epoch_time)))

        # evaluate on testing set
        print('testing the models......')
        test_start_time = time.time()
        test(test_loader, nets, criterions)
        test_time = time.time() - test_start_time
        print('testing time is {:02}h{:02}m{:02}s'.format(*transform_time(test_time)))

        # save model: the frozen teacher never changes, so its weights are
        # only stored in the very first checkpoint
        print('saving models......')
        save_name = 'fitnet_r{}_r{}_{:>03}.ckp'.format(args.t_name[6:], args.s_name[6:], epoch)
        save_name = os.path.join(ckpt_dir, save_name)
        if epoch == 1:
            save_checkpoint({
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
            }, save_name)
        else:
            save_checkpoint({
                'epoch': epoch,
                'snet': snet.state_dict(),
            }, save_name)
Exemple #23
0
def main():
    """End-to-end optical-flow training entry point.

    Builds a hyper-parameter-derived save directory and TensorBoard writers,
    constructs either the default torch data loaders or NVIDIA DALI pipelines
    (selected by ``args.data_loader``), creates the flow model, optimizer and
    LR scheduler, then runs the train/validate loop, checkpointing with the
    best-EPE flag each epoch.

    Relies on module-level ``args``, ``best_EPE``, ``device`` and the
    ``models``/``datasets``/``flow_transforms`` registries.
    """
    global args, best_EPE
    args = parser.parse_args()
    # Encode the main hyper-parameters into the run directory name.
    save_path = "{},{},{}epochs{},b{},lr{}".format(
        args.arch,
        args.solver,
        args.epochs,
        ",epochSize" + str(args.epoch_size) if args.epoch_size > 0 else "",
        args.batch_size,
        args.lr,
    )
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print("=> will save everything to {}".format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # One scalar writer per split, plus three writers for validation images.
    train_writer = SummaryWriter(os.path.join(save_path, "train"))
    test_writer = SummaryWriter(os.path.join(save_path, "test"))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, "test", str(i))))

    # Data loading code
    if args.data_loader == "torch":
        print("Using default data loader \n")
        # Scale images to [0,1] then subtract the per-channel mean.
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
        ])
        # Flow targets are divided by div_flow to keep magnitudes small.
        target_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0],
                                 std=[args.div_flow, args.div_flow]),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((122, 162)),
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
        ])

        # KITTI ground-truth flow is sparse, which disables the geometric
        # augmentations that would invalidate it.
        if "KITTI" in args.dataset:
            args.sparse = True
        if args.sparse:
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomCrop((122, 162)),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip(),
            ])
        else:
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomTranslate(10),
                flow_transforms.RandomRotate(10, 5),
                flow_transforms.RandomCrop((122, 162)),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip(),
            ])

        print("=> fetching img pairs in '{}'".format(args.data))
        train_set, test_set = datasets.__dict__[args.dataset](
            args.data,
            transform=input_transform,
            test_transform=test_transform,
            target_transform=target_transform,
            co_transform=co_transform,
            split=args.split_file if args.split_file else args.split_value,
        )
        print("{} samples found, {} train samples and {} test samples ".format(
            len(test_set) + len(train_set), len(train_set), len(test_set)))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=True,
        )
        val_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=False,
        )

    if args.data_loader == "dali":
        print("Using NVIDIA DALI \n")
        (
            (image0_train_names, image0_val_names),
            (image1_train_names, image1_val_names),
            (flow_train_names, flow_val_names),
        ) = make_dali_dataset(
            args.data,
            split=args.split_file if args.split_file else args.split_value)
        print("{} samples found, {} train samples and {} test samples ".format(
            len(image0_val_names) + len(image0_train_names),
            len(image0_train_names),
            len(image0_val_names),
        ))
        global train_length
        global val_length
        train_length = len(image0_train_names)
        val_length = len(image0_val_names)

        def create_image_pipeline(
            batch_size,
            num_threads,
            device_id,
            image0_list,
            image1_list,
            flow_list,
            valBool,
        ):
            # Build one DALI pipeline that reads an image pair plus its flow,
            # applies the (train-only) augmentations, and normalizes.
            pipeline = Pipeline(batch_size, num_threads, device_id, seed=2)
            with pipeline:
                if valBool:
                    shuffleBool = False
                else:
                    shuffleBool = True
                """ READ FILES """
                # The shared seed keeps the three readers aligned on the same
                # sample ordering.
                image0, _ = fn.readers.file(
                    file_root=args.data,
                    files=image0_list,
                    random_shuffle=shuffleBool,
                    name="Reader",
                    seed=1,
                )
                image1, _ = fn.readers.file(
                    file_root=args.data,
                    files=image1_list,
                    random_shuffle=shuffleBool,
                    seed=1,
                )
                flo = fn.readers.numpy(
                    file_root=args.data,
                    files=flow_list,
                    random_shuffle=shuffleBool,
                    seed=1,
                )
                """ DECODE AND RESHAPE """
                image0 = fn.decoders.image(image0, device="cpu")
                image0 = fn.reshape(image0, layout="HWC")
                image1 = fn.decoders.image(image1, device="cpu")
                image1 = fn.reshape(image1, layout="HWC")
                # Both frames are stacked channel-wise into one 6-channel input.
                images = fn.cat(image0, image1, axis=2)
                flo = fn.reshape(flo, layout="HWC")

                if valBool:
                    images = fn.resize(images, resize_x=162, resize_y=122)
                else:
                    """ CO-TRANSFORM """
                    # random translate
                    # angle_rng = fn.random.uniform(range=(-90, 90))
                    # images = fn.rotate(images, angle=angle_rng, fill_value=0)
                    # flo = fn.rotate(flo, angle=angle_rng, fill_value=0)

                    # The same seed on both crops keeps image and flow crops
                    # spatially aligned.
                    images = fn.random_resized_crop(
                        images,
                        size=[122, 162],  # 122, 162
                        random_aspect_ratio=[1.3, 1.4],
                        random_area=[0.8, 0.9],
                        seed=1,
                    )
                    flo = fn.random_resized_crop(
                        flo,
                        size=[122, 162],
                        random_aspect_ratio=[1.3, 1.4],
                        random_area=[0.8, 0.9],
                        seed=1,
                    )

                    # coin1 = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=10)
                    # coin1_n = coin1 ^ True
                    # coin2 = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=20)
                    # coin2_n = coin2 ^ True

                    # images = (
                    #     fn.flip(images, horizontal=1, vertical=1) * coin1 * coin2
                    #     + fn.flip(images, horizontal=1) * coin1 * coin2_n
                    #     + fn.flip(images, vertical=1) * coin1_n * coin2
                    #     + images * coin1_n * coin2_n
                    # )
                    # flo = (
                    #     fn.flip(flo, horizontal=1, vertical=1) * coin1 * coin2
                    #     + fn.flip(flo, horizontal=1) * coin1 * coin2_n
                    #     + fn.flip(flo, vertical=1) * coin1_n * coin2
                    #     + flo * coin1_n * coin2_n
                    # )
                    # _flo = flo
                    # flo_0 = fn.slice(_flo, axis_names="C", start=0, shape=1)
                    # flo_1 = fn.slice(_flo, axis_names="C", start=1, shape=1)
                    # flo_0 = flo_0 * coin1 * -1 + flo_0 * coin1_n
                    # flo_1 = flo_1 * coin2 * -1 + flo_1 * coin2_n
                    # # flo  = noflip + vertical flip + horizontal flip + both_flip

                    # # A horizontal flip is around the vertical axis (switch left and right)
                    # # So for a vertical flip coin1 is activated and needs to give +1, coin2 is activated needs to give -1
                    # # for a horizontal flip coin1 is activated and needs to be -1, coin2_n needs +1
                    # # no flip coin coin1_n +1, coin2_n +1

                    # flo = fn.cat(flo_0, flo_1, axis_name="C")
                """ NORMALIZE """
                # Same two-step normalization as the torch loader, applied to
                # the 6-channel concatenated pair.
                images = fn.crop_mirror_normalize(
                    images,
                    mean=[0, 0, 0, 0, 0, 0],
                    std=[255, 255, 255, 255, 255, 255])
                images = fn.crop_mirror_normalize(
                    images,
                    mean=[0.45, 0.432, 0.411, 0.45, 0.432, 0.411],
                    std=[1, 1, 1, 1, 1, 1],
                )
                flo = fn.crop_mirror_normalize(
                    flo, mean=[0, 0], std=[args.div_flow, args.div_flow])

                pipeline.set_outputs(images, flo)
            return pipeline

        class DALILoader:
            # Thin wrapper giving a DALI pipeline the len/iter interface of a
            # torch DataLoader.
            def __init__(
                self,
                batch_size,
                image0_names,
                image1_names,
                flow_names,
                valBool,
                num_threads,
                device_id,
            ):
                self.pipeline = create_image_pipeline(
                    batch_size,
                    num_threads,
                    device_id,
                    image0_names,
                    image1_names,
                    flow_names,
                    valBool,
                )
                self.pipeline.build()
                self.epoch_size = self.pipeline.epoch_size(
                    "Reader") / batch_size

                # NOTE(review): both branches below construct an identical
                # iterator; the valBool split looks vestigial.
                output_names = ["images", "flow"]
                if valBool:
                    self.dali_iterator = pytorch.DALIGenericIterator(
                        self.pipeline,
                        output_names,
                        reader_name="Reader",
                        last_batch_policy=pytorch.LastBatchPolicy.PARTIAL,
                        auto_reset=True,
                    )
                else:
                    self.dali_iterator = pytorch.DALIGenericIterator(
                        self.pipeline,
                        output_names,
                        reader_name="Reader",
                        last_batch_policy=pytorch.LastBatchPolicy.PARTIAL,
                        auto_reset=True,
                    )

            def __len__(self):
                return int(self.epoch_size)

            def __iter__(self):
                return self.dali_iterator.__iter__()

            def reset(self):
                return self.dali_iterator.reset()

        train_loader = DALILoader(
            batch_size=args.batch_size,
            num_threads=args.workers,
            device_id=0,
            image0_names=image0_train_names,
            image1_names=image1_train_names,
            flow_names=flow_train_names,
            valBool=False,
        )

        val_loader = DALILoader(
            batch_size=args.batch_size,
            num_threads=args.workers,
            device_id=0,
            image0_names=image0_val_names,
            image1_names=image1_val_names,
            flow_names=flow_val_names,
            valBool=True,
        )

    # create model
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        args.arch = network_data["arch"]
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data).to(device)

    assert args.solver in ["adam", "sgd"]
    print("=> setting {} solver".format(args.solver))
    # Biases get their own weight-decay setting, separate from the weights.
    param_groups = [
        {
            "params": model.bias_parameters(),
            "weight_decay": args.bias_decay
        },
        {
            "params": model.weight_parameters(),
            "weight_decay": args.weight_decay
        },
    ]

    if device.type == "cuda":
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    if args.solver == "adam":
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == "sgd":
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    # Halve the LR at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.5)

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch

        # # --- quant
        # model.train()
        # model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')  # torch.quantization.default_qconfig
        # # model = torch.quantization.fuse_modules(model, [['Conv2d', 'bn', 'relu']])
        # torch.backends.quantized.engine = 'qnnpack'
        # model = torch.quantization.prepare_qat(model)
        # # --- quant

        # my_sample = next(itertools.islice(train_loader, 10, None))
        # print(my_sample[1][0])
        # print("Maximum value is ", torch.max(my_sample[0][0]))
        # print("Minimum value is ", torch.min(my_sample[0][0]))

        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      train_writer)
        train_writer.add_scalar("mean EPE", train_EPE, epoch)

        scheduler.step()

        # evaluate on validation set

        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar("mean EPE", EPE, epoch)

        # best_EPE is presumably initialized to a negative sentinel at module
        # level, so the first epoch always seeds it.
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        # if is_best:
        #     kernels = model.module.conv3_1[0].weight.data
        #     kernels = kernels.cpu()
        #     kernels = kernels - kernels.min()
        #     kernels = kernels / kernels.max()
        #     img = make_grid(kernels)
        #     plt.imshow(img.permute(1, 2, 0))
        #     plt.show()
        # NOTE(review): `dummy_input` is not defined in this function and is
        # not declared global — presumably a module-level value; verify it
        # exists or this call raises NameError.
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.module.state_dict(),
                "best_EPE": best_EPE,
                "div_flow": args.div_flow,
            },
            is_best,
            save_path,
            model,
            dummy_input,
        )
Exemple #24
0
                loss=class_loss+recon_loss+0.1*(recon_loss2+recon_loss1)
            loss.backward()
            optimizer.step()
        global_t+=1
        running_loss += loss.item()* labels.size(0)
        running_corrects += torch.sum(preds.data.cpu() == atarget.data.cpu())
        if i % 20== 0 and i>0:    
            print('epoch:', epoch, '| global_t:', global_t,'| class_loss:', class_loss.item(), '| kl_loss:', kl_loss.item(),'| recon_loss:', recon_loss.item(),'| recon_loss1:', recon_loss1.item(),'| recon_loss2:', recon_loss2.item(),'| all_loss:',loss.item())            
            print("prediction action:",preds.data.cpu())
            print("ground truth action:",atarget.data.cpu())
            print("ground truth action:",a.data.cpu())
            acc=running_corrects.double() / (i*BATCH_SIZE)
            print("acc:",acc)
 
        if i % 1000== 0:
            util.save_checkpoint({'global_t': global_t,'state_dict': net.state_dict(),'epoch_loss':None,'epoch_act_acc':None,'epoch_sce_acc':None}, global_t) 
            print("testacc:",testacc)    
        del x0,x1,x2,x3,x4,x5, xnext,g,labels
    epoch_loss = running_loss / (i*BATCH_SIZE)
    epoch_acc = running_corrects.double() / (i*BATCH_SIZE)

    print('{} Loss: {:.4f} act_Acc: {:.4f}'.format(
                'train', epoch_loss, epoch_acc))
    util.save_checkpoint({'global_t': global_t,'state_dict': net.state_dict(),'epoch_loss':epoch_loss,'epoch_act_acc':epoch_acc}, global_t)  
    running_corrects = 0
    #==================================for test============================================================================================
    j=0
    for x0,x1,x2,x3,x4,x5, xnext,g,labels,strname,preact,pre_labeld  in test_loader:
        j+=1   
        x0= x0.to(device)
        x1= x1.to(device)
Exemple #25
0
def main():
    """Classification training/evaluation entry point.

    Seeds the RNGs, loads data and model per the module-level ``args``,
    optionally warm-starts and/or resumes from the latest checkpoint, then
    runs the train/eval loop with step LR decay, TensorBoard logging and
    per-epoch checkpointing, stopping early when the next epoch would exceed
    the allotted wall-clock budget.
    """
    np.random.seed(0)
    torch.manual_seed(0)

    logger.info('Loading data...')
    train_loader, val_loader, classes = custom_dataset.load_data(args)

    # override autodetect if n_classes is given
    if args.n_classes > 0:
        classes = np.arange(args.n_classes)

    model = load_model(classes)

    logger.info('Loaded model; params={}'.format(util.count_parameters(model)))
    if not args.cpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    model.to(device)
    cudnn.benchmark = True
    logger.info('Running on ' + str(device))

    summary_writer = Logger(args.logdir)

    # Loss and Optimizer
    n_epochs = args.epochs
    # NOTE(review): with label smoothing the criterion expects soft/one-hot
    # targets (BCEWithLogitsLoss) — confirm train() builds them accordingly.
    if args.label_smoothing > 0:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    train_state = init_train_state()
    # freeze layers
    for l in args.freeze_layers:
        for p in getattr(model, l).parameters():
            p.requires_grad = False
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=train_state['lr'],
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'nesterov':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=train_state['lr'],
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    # this is used to warm-start
    if args.warm_start_from:
        logger.info('Warm-starting from {}'.format(args.warm_start_from))
        assert os.path.isfile(args.warm_start_from)
        train_state = load_checkpoint(args.warm_start_from, model, optimizer)
        logger.info('Params loaded.')
        # do not override train_state these when warm staring
        train_state = init_train_state()

    # Resuming from the latest checkpoint (if any) takes precedence over the
    # warm start for the training-state counters.
    ckptfile = str(Path(args.logdir) / args.latest_fname)
    if os.path.isfile(ckptfile):
        logger.info('Loading checkpoint: {}'.format(ckptfile))
        train_state = load_checkpoint(ckptfile, model, optimizer)
        logger.info('Params loaded.')
    else:
        logger.info('Checkpoint {} not found; ignoring.'.format(ckptfile))

    # Training / Eval loop
    epoch_time = []                 # store time per epoch
    # we save epoch+1 to checkpoints; but for eval we should repeat prev. epoch
    if args.skip_train:
        train_state['start_epoch'] -= 1
    for epoch in range(train_state['start_epoch'], n_epochs):
        logger.info('Epoch: [%d/%d]' % (epoch + 1, n_epochs))
        start = time.time()

        if not args.skip_train:
            model.train()
            train(train_loader, device, model, criterion, optimizer, summary_writer, train_state,
                  n_classes=len(classes))
            logger.info('Time taken: %.2f sec...' % (time.time() - start))
            if epoch == 0:
                train_state['steps_epoch'] = train_state['step']
        # always eval on last epoch
        if not args.skip_eval or epoch == n_epochs - 1:
            logger.info('\n Starting evaluation...')
            model.eval()
            # Retrieval metrics only on the final epoch, when configured.
            eval_shrec = True if epoch == n_epochs - 1 and args.retrieval_dir else False
            metrics, inputs = eval(
                val_loader, device, model, criterion, eval_shrec)

            logger.info('\tcombined: %.2f, Acc: %.2f, mAP: %.2f, Loss: %.4f' %
                        (metrics['combined'],
                         metrics['acc_inst'],
                         metrics.get('mAP_inst', 0.),
                         metrics['loss']))

            # Log epoch to tensorboard
            # See log using: tensorboard --logdir='logs' --port=6006
            ims = get_summary_ims(inputs)
            if not args.nolog:
                util.logEpoch(summary_writer, model, epoch + 1, metrics, ims)
        else:
            metrics = None

        # Decaying Learning Rate
        if args.lr_decay_mode == 'step':
            if (epoch + 1) % args.lr_decay_freq == 0:
                train_state['lr'] *= args.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = train_state['lr']

        # Save model
        if not args.skip_train:
            logger.info('\tSaving latest model')
            util.save_checkpoint({
                'epoch': epoch + 1,
                'step': train_state['step'],
                'steps_epoch': train_state['steps_epoch'],
                'state_dict': model.state_dict(),
                'metrics': metrics,
                'optimizer': optimizer.state_dict(),
                'lr': train_state['lr'],
            },
                str(Path(args.logdir) / args.latest_fname))

        total_epoch_time = time.time() - start
        epoch_time.append(total_epoch_time)
        logger.info('Total time for this epoch: {} s'.format(total_epoch_time))

        # if last epoch, show eval results
        if epoch == n_epochs - 1:
            logger.info(
                '|model|combined|acc inst|acc cls|mAP inst|mAP cls|loss|')
            logger.info('|{}|{:.2f}|{:.2f}|{:.2f}|{:.2f}|{:.2f}|{:.4f}|'
                        .format(os.path.basename(args.logdir),
                                metrics['combined'],
                                metrics['acc_inst'],
                                metrics['acc_cls'],
                                metrics.get('mAP_inst', 0.),
                                metrics.get('mAP_cls', 0.),
                                metrics['loss']))

        if args.skip_train:
            # if evaluating, run it once
            break

        # NOTE(review): `start_time` is presumably a module-level
        # time.perf_counter() value — verify; mixing it with time.time()
        # elsewhere would make the budget check wrong.
        if time.perf_counter() + np.max(epoch_time) > start_time + args.exit_after:
            logger.info('Next epoch will likely exceed alotted time; exiting...')
            break
            # NOTE(review): everything from here to the end of the function
            # is unreachable (it follows `break`) and references names never
            # defined in this script (i, total_step, loss, sample_step, x,
            # x_hat, opt, generator, ...) — it appears to be a fragment
            # pasted from a different (GAN) training script.
            print('Epoch [%d/%d], Step[%d/%d], loss: %f, l1: %f, lap: %f'
                  % (epoch + 1, n_epochs, i + 1, total_step, loss.data[0], l1_loss.data[0],
                     lap_loss.data[0]
                     ))

        # save the real images
        if (i + 1) == sample_step:
            torchvision.utils.save_image(denorm(x.data),
                                         os.path.join(sample_path,
                                                      'real_samples-%d-%d.png' % (
                                                          epoch + 1, i + 1)), nrow=4)
        # save the generated images
        if (i + 1) % sample_step == 0:
            torchvision.utils.save_image(denorm(x_hat.data),
                                         os.path.join(sample_path,
                                                      'fake_samples-%d-%d.png' % (
                                                          epoch + 1, i + 1)), nrow=4)

    if (epoch + 1) % opt.ckpt_step == 0:
        print("saving checkpoint ..")
        checkpoint_path = os.path.join(model_path, 'checkpoint_%d.pth.tar' % (epoch + 1))

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': generator.state_dict(),
            'latent': learnable_z,
            'optimizer': g_optimizer.state_dict(),
            'args': opt
        }, filename=checkpoint_path)
        print("done.")
Exemple #27
0
def main():
    """Train (or evaluate) the configured spiking optical-flow network.

    Builds a timestamped save directory with TensorBoard writers, loads the
    test set (and, unless evaluating, the augmented train set), instantiates
    the requested architecture (optionally from a pre-trained checkpoint),
    and runs the epoch loop, checkpointing whenever validation EPE improves.
    Reads configuration from the module-level ``args``.
    """
    global args, best_EPE, image_resize, event_interval, spiking_ts, device, sp_threshold
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.savedir, save_path)
    print('=> Everything will be saved to {}'.format(save_path))

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # One writer per phase, plus one per visualized validation sample.
    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, 'test', str(i))))

    # Data loading code: joint augmentations applied to training samples.
    co_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.5),
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop((256, 256),
                                     scale=(0.5, 1.0),
                                     ratio=(0.75, 1.3333333333333333),
                                     interpolation=2),
        transforms.ToTensor(),
    ])
    Test_dataset = Test_loading()
    test_loader = DataLoader(dataset=Test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.workers)

    # Create the model, optionally seeded with pre-trained weights.
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data).cuda()
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    # Biases and weights get independent weight-decay settings.
    param_groups = [{
        'params': model.module.bias_parameters(),
        'weight_decay': args.bias_decay
    }, {
        'params': model.module.weight_parameters(),
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    if args.evaluate:
        with torch.no_grad():
            best_EPE = validate(test_loader, model, -1, output_writers)
        return

    Train_dataset = Train_loading(transform=co_transform)
    train_loader = DataLoader(dataset=Train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.7)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss = train(train_loader, model, optimizer, epoch, train_writer)
        train_writer.add_scalar('mean loss', train_loss, epoch)

        # BUG FIX: step the LR schedule AFTER the optimizer updates of this
        # epoch. Stepping at the top of the loop skipped the initial learning
        # rate and shifted every milestone one epoch early (PyTorch >= 1.1
        # defines this ordering).
        scheduler.step()

        # Evaluate every `evaluate_interval` epochs during training.
        if (epoch + 1) % args.evaluate_interval == 0:
            # evaluate on validation set
            with torch.no_grad():
                EPE = validate(test_loader, model, epoch, output_writers)
            test_writer.add_scalar('mean EPE', EPE, epoch)

            # best_EPE starts negative as a "not yet measured" sentinel.
            if best_EPE < 0:
                best_EPE = EPE

            is_best = EPE < best_EPE
            best_EPE = min(EPE, best_EPE)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_EPE': best_EPE,
                    'div_flow': args.div_flow
                }, is_best, save_path)
Exemple #28
0
def main():
    """Quantization-aware training entry point (LSQ).

    Parses the YAML config, sets up logging/TensorBoard monitors and the
    GPU/CPU device list, loads the data, replaces selected modules with
    quantized versions, then either evaluates a checkpoint (``args.eval``)
    or runs the train/validate loop with best-model checkpointing.
    """
    script_dir = Path.cwd()
    args = util.get_config(default_file=script_dir / 'config.yaml')

    output_dir = script_dir / args.output_dir
    output_dir.mkdir(exist_ok=True)

    log_dir = util.init_logger(args.name, output_dir,
                               script_dir / 'logging.conf')
    logger = logging.getLogger()

    with open(log_dir / "args.yaml",
              "w") as yaml_file:  # dump experiment config
        yaml.safe_dump(args, yaml_file)

    pymonitor = util.ProgressMonitor(logger)
    tbmonitor = util.TensorBoardMonitor(logger, log_dir)
    monitors = [pymonitor, tbmonitor]

    # Fall back to CPU when CUDA is unavailable or no GPUs were requested;
    # otherwise validate every requested device ID before touching CUDA.
    if args.device.type == 'cpu' or not t.cuda.is_available(
    ) or args.device.gpu == []:
        args.device.gpu = []
    else:
        available_gpu = t.cuda.device_count()
        for dev_id in args.device.gpu:
            if dev_id >= available_gpu:
                logger.error(
                    'GPU device ID {0} requested, but only {1} devices available'
                    .format(dev_id, available_gpu))
                exit(1)
        # Set default device in case the first one on the list
        t.cuda.set_device(args.device.gpu[0])
        # Enable the cudnn built-in auto-tuner to accelerating training, but it
        # will introduce some fluctuations in a narrow range.
        t.backends.cudnn.benchmark = True
        t.backends.cudnn.deterministic = False

    # Initialize data loader
    train_loader, val_loader, test_loader = util.load_data(args.dataloader)
    logger.info('Dataset `%s` size:' % args.dataloader.dataset +
                '\n          Training Set = %d (%d)' %
                (len(train_loader.sampler), len(train_loader)) +
                '\n        Validation Set = %d (%d)' %
                (len(val_loader.sampler), len(val_loader)) +
                '\n              Test Set = %d (%d)' %
                (len(test_loader.sampler), len(test_loader)))

    # Create the model and swap quantizable modules for quantized versions.
    model = create_model(args)
    modules_to_replace = quan.find_modules_to_quantize(model, args.quan)
    model = quan.replace_module_by_names(model, modules_to_replace)
    # Log the model graph with one real sample as the tracing input.
    tbmonitor.writer.add_graph(
        model, input_to_model=train_loader.dataset[0][0].unsqueeze(0))
    logger.info('Inserted quantizers into the original model')

    if args.device.gpu and not args.dataloader.serialized:
        model = t.nn.DataParallel(model, device_ids=args.device.gpu)
    model.to(args.device.type)

    # Optionally resume: load_checkpoint returns the epoch to continue from.
    start_epoch = 0
    if args.resume.path:
        model, start_epoch, _ = util.load_checkpoint(model,
                                                     args.resume.path,
                                                     args.device.type,
                                                     lean=args.resume.lean)

    # Define loss function (criterion) and optimizer
    criterion = t.nn.CrossEntropyLoss().to(args.device.type)

    # optimizer = t.optim.Adam(model.parameters(), lr=args.optimizer.learning_rate)
    optimizer = t.optim.SGD(model.parameters(),
                            lr=args.optimizer.learning_rate,
                            momentum=args.optimizer.momentum,
                            weight_decay=args.optimizer.weight_decay)
    lr_scheduler = util.lr_scheduler(optimizer,
                                     batch_size=train_loader.batch_size,
                                     num_samples=len(train_loader.sampler),
                                     **args.lr_scheduler)
    logger.info(('Optimizer: %s' % optimizer).replace('\n', '\n' + ' ' * 11))
    logger.info('LR scheduler: %s\n' % lr_scheduler)

    perf_scoreboard = process.PerformanceScoreboard(args.log.num_best_scores)

    if args.eval:
        process.validate(test_loader, model, criterion, -1, monitors, args)
    else:  # training
        # Score the starting weights first so resumed/pre-trained runs have a
        # baseline entry on the scoreboard.
        if args.resume.path or args.pre_trained:
            logger.info('>>>>>>>> Epoch -1 (pre-trained model evaluation)')
            top1, top5, _ = process.validate(val_loader, model, criterion,
                                             start_epoch - 1, monitors, args)
            perf_scoreboard.update(top1, top5, start_epoch - 1)
        for epoch in range(start_epoch, args.epochs):
            logger.info('>>>>>>>> Epoch %3d' % epoch)
            t_top1, t_top5, t_loss = process.train(train_loader, model,
                                                   criterion, optimizer,
                                                   lr_scheduler, epoch,
                                                   monitors, args)
            v_top1, v_top5, v_loss = process.validate(val_loader, model,
                                                      criterion, epoch,
                                                      monitors, args)

            tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {
                'train': t_loss,
                'val': v_loss
            }, epoch)
            tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {
                'train': t_top1,
                'val': v_top1
            }, epoch)
            tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {
                'train': t_top5,
                'val': v_top5
            }, epoch)

            # Checkpoint every epoch; the scoreboard decides "best" by
            # validation accuracy.
            perf_scoreboard.update(v_top1, v_top5, epoch)
            is_best = perf_scoreboard.is_best(epoch)
            util.save_checkpoint(epoch, args.arch, model, {
                'top1': v_top1,
                'top5': v_top5
            }, is_best, args.name, log_dir)

        logger.info('>>>>>>>> Epoch -1 (final model evaluation)')
        process.validate(test_loader, model, criterion, -1, monitors, args)

    tbmonitor.writer.close()  # close the TensorBoard
    logger.info('Program completed successfully ... exiting ...')
    logger.info(
        'If you have any questions or suggestions, please visit: github.com/zhutmost/lsq-net'
    )
Exemple #29
0
def main():
    """Train or evaluate an optical-flow network on the configured dataset.

    Resolves the data location (CLI flag or ``train_src/data_loc.json``),
    builds per-dataset transforms (sparse KITTI flow vs. dense), constructs
    the model and optimizer, then runs evaluation/demo modes or the training
    loop with best-EPE checkpointing. Reads/updates the module globals
    ``args`` and ``best_EPE``.
    """
    global args, best_EPE
    args = parser.parse_args()

    if not args.data:
        # Look up the dataset root from the bundled location map.
        # (Context manager guarantees the file is closed even on error.)
        with open('train_src/data_loc.json', 'r') as f:
            data_loc = json.loads(f.read())
        args.data = data_loc[args.dataset]

    if not args.savpath:
        save_path = '{},{},{}epochs{},b{},lr{}'.format(
            args.arch,
            args.solver,
            args.epochs,
            ',epochSize'+str(args.epoch_size) if args.epoch_size > 0 else '',
            args.batch_size,
            args.lr)
        if not args.no_date:
            timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
            save_path = os.path.join(timestamp,save_path)
    else:
        save_path = args.savpath
    save_path = os.path.join(args.dataset,save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # save training args
    save_training_args(save_path, args)

    train_writer = SummaryWriter(os.path.join(save_path,'train'))
    test_writer = SummaryWriter(os.path.join(save_path,'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(SummaryWriter(os.path.join(save_path,'test',str(i))))

    # Data loading code
    if args.grayscale:
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Grayscale(num_output_channels=3),
            transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
            transforms.Normalize(mean=[0.431,0.431,0.431], std=[1,1,1]) # 0.431=(0.45+0.432+0.411)/3
        ])
    else:
        input_transform = transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
            transforms.Normalize(mean=[0.45,0.432,0.411], std=[1,1,1])
        ])

    # Flow targets are scaled down by div_flow for numerical stability.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0,0],std=[args.div_flow,args.div_flow])
    ])

    # KITTI ground truth is sparse; skip augmentations that would corrupt it.
    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320,448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10,5),
            flow_transforms.RandomCrop((320,448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value
    )
    # BUG FIX: corrected the 'samp-les' typo in the summary message.
    print('{} samples found, {} train samples and {} test samples '.format(len(test_set)+len(train_set),
                                                                           len(train_set),
                                                                           len(test_set)))
    if not args.evaluate:
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size,
            num_workers=args.workers, pin_memory=True, shuffle=True)

    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True, shuffle=False)

    # Create the model, optionally seeded with pre-trained weights.
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](data=network_data, args=args).to(device)

    assert(args.solver in ['adam', 'sgd', 'adamw'])
    print('=> setting {} solver'.format(args.solver))
    # Biases and weights get independent weight-decay settings. Collect the
    # parameter groups before any DataParallel wrapping.
    param_groups = [{'params': model.bias_parameters(), 'weight_decay': args.bias_decay},
                    {'params': model.weight_parameters(), 'weight_decay': args.weight_decay}]

    if device.type == "cuda":
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, args.lr,
                                    momentum=args.momentum)
    elif args.solver == 'adamw':
        optimizer = torch.optim.AdamW(param_groups, args.lr,
                                    betas=(args.momentum, args.beta))

    if args.print_model:
        exportpars(model, save_path, args)
        exportsummary(model, save_path, args)
        if args.savpath == 'test':
            return

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    if args.demo:
        demo(val_loader, model, 0, output_writers)
        return
    if args.demovideo:
        demovideo(val_loader, model, 0, output_writers)
        return

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)

    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)
        scheduler.step()

        # evaluate on validation set
        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        # best_EPE starts negative as an "unset" sentinel.
        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        # BUG FIX: the model is only wrapped in DataParallel on CUDA, so
        # unconditionally reading `model.module` crashed CPU runs.
        state_dict = (model.module.state_dict()
                      if isinstance(model, torch.nn.DataParallel)
                      else model.state_dict())
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': state_dict,
            'best_EPE': best_EPE,
            'div_flow': args.div_flow
        }, is_best, save_path)
def train_MVCNN(case_description):
    """Fine-tune a pre-trained MVCNN (or ResNet) on a multi-view dataset.

    Loads a frozen pre-trained backbone, replaces its final classifier layer
    to match the current class set, trains only that layer, checkpoints
    whenever training accuracy improves, and finally plots the loss/accuracy
    curves (shown locally, saved to disk on remote runs).

    Args:
        case_description: Free-form tag stored alongside saved checkpoints.
    """
    print("\nTrain MVCNN\n")
    MVCNN = 'mvcnn'
    RESNET = 'resnet'
    MODELS = [RESNET, MVCNN]

    # Run configuration (hard-coded for this experiment).
    DATA_PATH = globals.DATA_PATH
    DEPTH = None
    MODEL = MODELS[1]
    EPOCHS = 100
    BATCH_SIZE = 10
    LR = 0.0001
    MOMENTUM = 0.9  # unused with Adam; kept to document the intended config
    LR_DECAY_FREQ = 30
    LR_DECAY = 0.1
    PRINT_FREQ = 10
    RESUME = ""
    PRETRAINED = True

    REMOTE = globals.REMOTE

    print('Loading data')

    transform = transforms.Compose([
        transforms.CenterCrop(500),
        transforms.Resize(224),
        transforms.ToTensor(),
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load dataset (no validation split is used here by design).
    dset_train = MultiViewDataSet(DATA_PATH, 'train', transform=transform)
    train_loader = DataLoader(dset_train,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=2)

    classes = dset_train.classes
    print(len(classes), classes)

    if MODEL == RESNET:
        if DEPTH == 18:
            model = resnet18(pretrained=PRETRAINED, num_classes=len(classes))
        elif DEPTH == 34:
            model = resnet34(pretrained=PRETRAINED, num_classes=len(classes))
        elif DEPTH == 50:
            model = resnet50(pretrained=PRETRAINED, num_classes=len(classes))
        elif DEPTH == 101:
            model = resnet101(pretrained=PRETRAINED, num_classes=len(classes))
        elif DEPTH == 152:
            model = resnet152(pretrained=PRETRAINED, num_classes=len(classes))
        else:
            raise Exception(
                'Specify number of layers for resnet in command line. --resnet N'
            )
        print('Using ' + MODEL + str(DEPTH))
    else:
        # number of ModelNet40 needs to match loaded pre-trained model
        model = mvcnn(pretrained=PRETRAINED, num_classes=40)
        print('Using ' + MODEL)

    cudnn.benchmark = True

    print('Running on ' + str(device))

    # Load pre-trained weights and freeze them; only the replaced classifier
    # layer below remains trainable. The path differs remotely vs. locally to
    # avoid duplicating the checkpoint in the repo.
    if REMOTE:
        PATH = "../../MVCNN_Peter/checkpoint/mvcnn18_checkpoint.pth.tar"
    else:
        PATH = "checkpoint/model_from_pete.tar"

    loaded_model = torch.load(PATH)
    model.load_state_dict(loaded_model['state_dict'])
    for param in model.parameters():
        param.requires_grad = False
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, len(classes))

    model.to(device)

    print(model)

    logger = Logger('logs')

    # Loss and Optimizer
    lr = LR
    n_epochs = EPOCHS
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    best_loss = 0.0
    start_epoch = 0

    # Helper functions
    def load_checkpoint():
        """Restore model/optimizer state and resume bookkeeping from RESUME."""
        # BUG FIX: this was `global`, which targets module scope, so a resumed
        # run silently kept best_acc/start_epoch at their initial values.
        # `nonlocal` updates the enclosing function's variables as intended.
        nonlocal best_acc, start_epoch
        print('\n==> Loading checkpoint..')
        assert os.path.isfile(RESUME), 'Error: no checkpoint file found!'

        checkpoint = torch.load(RESUME)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    def train():
        """Run one epoch; return (last batch loss, integer accuracy in %)."""
        train_size = len(train_loader)
        loss = None
        total = 0
        correct = 0
        for i, (inputs, targets) in enumerate(train_loader):

            # Convert from list of 3D to 4D
            inputs = np.stack(inputs, axis=1)

            inputs = torch.from_numpy(inputs)

            # BUG FIX: `.cuda(device)` crashes when the device is the CPU;
            # `.to(device)` handles both CPU and CUDA.
            inputs, targets = inputs.to(device), targets.to(device)
            inputs, targets = Variable(inputs), Variable(targets)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted.cpu() == targets.cpu()).sum().item()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % PRINT_FREQ == 0:
                print("\tIter [%d/%d] Loss: %.4f" %
                      (i + 1, train_size, loss.item()))

        return loss, int(float(float(correct) / float(total)) * 100)

    # Training / Eval loop
    if RESUME:
        load_checkpoint()

    # BUG FIX: do NOT reset best_acc here -- load_checkpoint() may just have
    # restored it, and resetting made every resumed run re-save immediately.
    loss_values = []
    acc_values = []
    for epoch in range(start_epoch, n_epochs):
        print('\n-----------------------------------')
        print('Epoch: [%d/%d]' % (epoch + 1, n_epochs))
        start = time.time()

        model.train()
        (t_loss, t_acc) = train()
        loss_values.append(t_loss)
        acc_values.append(t_acc)

        print("Total loss: " + str(t_loss))
        print("Accuracy: " + str(t_acc) + "%")

        print('Time taken: %.2f sec.' % (time.time() - start))

        # Checkpoint on training-accuracy improvement (no validation split).
        if t_acc > best_acc:
            print("UPDATE")
            best_acc = t_acc
            best_loss = t_loss
            util.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss_per_epoch': loss_values,
                    'acc_per_epoch': acc_values,
                    'optimizer': optimizer.state_dict(),
                }, MODEL, DEPTH, case_description)

        # Decaying Learning Rate: recreate the optimizer with the decayed LR
        # (this discards Adam moment state; kept from the original design).
        if (epoch + 1) % LR_DECAY_FREQ == 0:
            lr *= LR_DECAY
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            print('Learning rate:', lr)

    fig, axs = plt.subplots(2)
    fig.suptitle('Training loss (top) and accuracy (bottom) per epoch')
    axs[0].plot(loss_values, 'r')
    axs[1].plot(acc_values, 'b')

    if not REMOTE:
        plt.show()
    else:
        plt.savefig("plots/training.png")