def main():
    """Train, test and checkpoint an image classifier driven by CLI arguments.

    Steps: parse arguments, build the initial network for ``args.arch``, load
    the train/valid/test splits from ``args.data_dir``, attach the classifier
    head, train, evaluate on the test split and save a checkpoint to
    ``args.save_dir``.
    """
    args = init_argparse()

    def check_gpu(gpu_arg):
        """Return the torch device to train on for the given ``--gpu`` flag.

        Falls back to the CPU (printing a warning) when a GPU was requested
        but CUDA is not available.
        """
        # If gpu_arg is false then simply return the cpu device
        if not gpu_arg:
            return torch.device("cpu")

        # If gpu_arg then make sure to check for CUDA before assigning it
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Compare the device *type*: a torch.device never equals a plain str,
        # so the old ``device == "cpu"`` test was always False and the
        # fallback warning was never shown.
        if device.type == "cpu":
            print("CUDA was not found on device, using CPU instead.")
        return device

    # Use an explicit ``args.device`` when the parser provides one; otherwise
    # derive the device from the --gpu flag (check_gpu was previously unused,
    # and reading a missing ``args.device`` raised AttributeError).
    device = getattr(args, "device", None) or check_gpu(args.gpu)

    # load initial model
    model = network.initial_model(args.arch)
    print('load initial model..')

    # load dataset:
    train_data, valid_data, test_data, train_loader, valid_loader, test_loader = utility.pre_data(
        args.data_dir)
    print('load dataset..')

    # build model
    model, criterion, optimizer, classifier = network.model_setup(
        model, args.input_size, args.arch, args.gpu, args.hidden_units,
        args.learning_rate)
    print('build model arch..')

    # train network
    network.network_training(criterion, optimizer, train_loader, valid_loader,
                             model, args.epochs, device)
    print('Network Training..')

    # test network
    network.network_test(model, criterion, test_loader, device)
    print('Test Network..')

    # save model
    utility.save_checkpoint(train_data, model, optimizer, args.save_dir)
    print('save model..')
# --- Пример #2 (Example #2) ---
# 0
def main():
    """Fine-tune a pretrained classifier and save a prediction checkpoint.

    Reads CLI options via ``get_args``, loads the image splits and a
    pretrained backbone with a fresh classifier head, trains and tests it
    through the ``utility`` helpers, then writes a checkpoint containing the
    architecture, weights, optimizer state, hidden units and class labels to
    ``save_dir``.
    """
    opts = get_args()

    labels, train_set, test_set, valid_set = utility.load_img(opts.data_dir)
    net = utility.load_pretrained_model(opts.arch, opts.hidden_units)

    loss_fn = nn.NLLLoss()
    # NOTE(review): this optimizer is not passed to utility.train (which only
    # receives the learning rate), so the optimizer state saved below may be
    # that of an untrained optimizer -- confirm against utility.train.
    adam = optim.Adam(net.classifier.parameters(), lr=opts.learning_rate)

    utility.train(net, opts.learning_rate, loss_fn, train_set, valid_set,
                  opts.epochs, opts.gpu)
    utility.test(net, test_set, opts.gpu)

    # Bring the weights back to the CPU before serialising them.
    net.to('cpu')

    # Save checkpoint for prediction
    checkpoint = {
        'arch': opts.arch,
        'state_dict': net.state_dict(),
        'optimizer': adam.state_dict(),
        'hidden_units': opts.hidden_units,
        'class_labels': labels,
    }
    utility.save_checkpoint(checkpoint, opts.save_dir)
    print('Saved checkpoint!')
# --- Пример #3 (Example #3) ---
# 0
def train(args, logger, device_ids):
    """Train AdaMatting with per-epoch validation, TensorBoard logging and checkpoints.

    Args:
        args: parsed options. Fields read here: ``lr``, ``resume`` (checkpoint
            path, ``""`` to start from scratch), ``cuda``, ``raw_data_path``,
            ``batch_size``, ``valid_portion``, ``epochs``, ``save_ckpt`` and
            ``ckpt_path``.
        logger: logger used for progress messages.
        device_ids: CUDA device indices, used only when ``args.cuda`` is true.
            The model is placed on ``device_ids[0]``; more than one id wraps
            it in ``torch.nn.DataParallel``.
    """
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        # Deserialize onto the CPU; tensors are moved to the training device
        # below. This also lets a GPU-saved checkpoint resume on a CPU host.
        # NOTE: torch.load unpickles the file -- only resume trusted checkpoints.
        ckpt = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        device = torch.device("cuda:{}".format(device_ids[0]))
        # Move restored optimizer state to the model's device. A bare
        # ``v.cuda()`` targets the *current* CUDA device, which differs from
        # the model's device whenever device_ids[0] != 0.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    # NOTE(review): drop_last=True on the validation loader skips the final
    # partial batch (and every sample if the split is smaller than one batch).
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

    # 43100 is a hard-coded sample count; valid_portion is treated as a percentage.
    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs
    # TensorBoard x-axis, rescaled so runs with different batch sizes line up
    # against a reference batch size of 16.
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter, max_iter=max_iter,
                                  max_decay_times=30, decay_rate=0.9)

            inputs = inputs.to(device)
            gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)

            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                        log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

            # Learned uncertainties are predicted as log(sigma^2); log sigma itself.
            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() / 2), torch.exp(log_sigma_a_sqr.mean() / 2)

            optimizer.zero_grad()
            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
                writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter)
                writer.add_scalar("other/lr", cur_lr, tensorboard_iter)

                # Averages are per logging window, not per epoch.
                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()

            cur_iter += 1
            tensorboard_iter = cur_iter * (args.batch_size / 16)

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        model.eval()
        # Scope gradient disabling to validation only; the previous
        # torch.set_grad_enabled(False) leaked and left autograd globally
        # disabled for the caller after train() returned.
        with torch.no_grad(), tqdm(total=len(valid_loader)) as pbar:
            for index, (display_rgb, inputs, gts) in enumerate(valid_loader):
                inputs = inputs.to(device) # [bs, 4, 320, 320]
                gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                            log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr)

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                # Log image grids for the first validation batch only.
                if index == 0:
                    input_rbg = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True)
                    writer.add_image('input/rbg_image', input_rbg, tensorboard_iter)

                    # Channel 3 of the network input is the trimap.
                    input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1)
                    input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True)
                    writer.add_image('input/trimap', input_trimap, tensorboard_iter)

                    # Compose the final alpha: predicted trimap class 0 -> 0.0,
                    # class 2 -> 1.0, otherwise the estimated alpha.
                    output_alpha = alpha_estimation.clone()
                    output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0
                    output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0
                    output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True)
                    writer.add_image('output/alpha', output_alpha, tensorboard_iter)

                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('pred/trimap_adaptation', trimap_adaption_res, tensorboard_iter)

                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False, scale_each=True)
                    writer.add_image('pred/alpha_estimation', alpha_estimation_res, tensorboard_iter)

                    gt_alpha = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True)
                    writer.add_image('gt/alpha', gt_alpha, tensorboard_iter)

                    gt_trimap = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    gt_trimap = torchvision.utils.make_grid(gt_trimap, normalize=False, scale_each=True)
                    writer.add_image('gt/trimap', gt_trimap, tensorboard_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        is_alpha_best = avg_l_a.avg < best_alpha_loss
        best_alpha_loss = min(avg_l_a.avg, best_alpha_loss)
        if is_best or is_alpha_best or args.save_ckpt:
            os.makedirs("ckpts", exist_ok=True)
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger, model=model, optimizer=optimizer,
                            epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr, best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
# --- Пример #4 (Example #4) ---
# 0
def train(model, optimizer, device, args, logger, multi_gpu):
    """Train AdaMatting (model-held uncertainty parameters) with per-epoch validation.

    Args:
        model: network exposing ``log_sigma_t_sqr`` / ``log_sigma_a_sqr``
            attributes; replaced by the checkpointed model when resuming.
        optimizer: optimizer for ``model``; replaced when resuming.
        device: torch device used for inputs and (on resume) the model.
        args: parsed options. Fields read here: ``raw_data_path``,
            ``batch_size``, ``resume``, ``ckpt_path``, ``valid_portion``,
            ``epochs``, ``lr`` and ``save_ckpt``.
        logger: logger used for progress messages.
        multi_gpu: currently unused (kept for interface compatibility).
    """
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, 'valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True)

    if args.resume:
        logger.info("Start training from saved ckpt")
        # NOTE: the checkpoint stores pickled model/optimizer objects, so
        # torch.load can execute code from the file -- trusted files only.
        ckpt = torch.load(args.ckpt_path)
        saved_model = ckpt["model"]
        # Unwrap a DataParallel-style wrapper if present; a plain module is
        # used as-is (``.module`` alone raised AttributeError for it).
        model = getattr(saved_model, "module", saved_model)
        model = model.to(device)
        optimizer = ckpt["optimizer"]

        start_epoch = ckpt["epoch"] + 1
        max_iter = ckpt["max_iter"]
        cur_iter = ckpt["cur_iter"]
        init_lr = ckpt["init_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        # 43100 is a hard-coded sample count; valid_portion is treated as a fraction here.
        max_iter = 43100 * (1 - args.valid_portion) / args.batch_size * args.epochs
        cur_iter = 0
        init_lr = args.lr
        best_loss = float('inf')

    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr = poly_lr_scheduler(optimizer=optimizer, init_lr=init_lr, iter=cur_iter, max_iter=max_iter)

            img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
            gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

            trimap_adaption, t_argmax, alpha_estimation = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
                writer.add_scalar("loss/L_overall", L_overall.item(), cur_iter)
                writer.add_scalar("loss/L_t", L_t.item(), cur_iter)
                writer.add_scalar("loss/L_a", L_a.item(), cur_iter)
                # The model stores log(sigma^2); log sigma itself.
                sigma_t = torch.exp(model.log_sigma_t_sqr / 2)
                sigma_a = torch.exp(model.log_sigma_a_sqr / 2)
                writer.add_scalar("sigma/sigma_t", sigma_t, cur_iter)
                writer.add_scalar("sigma/sigma_a", sigma_a, cur_iter)
                writer.add_scalar("lr", cur_lr, cur_iter)

            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        model.eval()
        # Scope gradient disabling to validation only; the previous
        # torch.set_grad_enabled(False) leaked out of this function.
        with torch.no_grad(), tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device) # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                            pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                            gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr, log_sigma_a_sqr=model.log_sigma_a_sqr)
                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                # Log image grids for the first validation batch only.
                if index == 0:
                    trimap_adaption_res = torchvision.utils.make_grid(t_argmax.type(torch.FloatTensor) / 2, normalize=True, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)

                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or (args.save_ckpt and epoch % 10 == 0):
            os.makedirs("ckpts", exist_ok=True)
            # Persist the running best loss (not this epoch's loss): resume
            # restores it as ``best_loss``, and storing a worse periodic value
            # would let a later non-best epoch be saved as "best".
            save_checkpoint(epoch, model, optimizer, cur_iter, max_iter, init_lr, best_loss, is_best, args.ckpt_path)
            logger.info("Checkpoint saved")
            if is_best:
                logger.info("Best checkpoint saved")

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
# --- Пример #5 (Example #5) ---
# 0
def main_worker(gpu, ngpus_per_node, args):
    """Set up, train and evaluate the longitudinal interval model on one worker.

    Args:
        gpu: index of the GPU assigned to this worker, or ``None``.
        ngpus_per_node: GPUs per node; used to derive the global rank and to
            split ``batch_size``/``workers`` per process in distributed mode.
        args: parsed command-line options (model, data, optimisation,
            distribution and evaluation settings).

    Side effects: updates the module-level ``best_prec1``, fills the
    module-level ``hyper_params`` dict and logs it via ``experiment``, and
    writes TensorBoard logs, CSVs and checkpoints through ``util``.

    Raises:
        ValueError: if no pre-trained model is requested and
            ``args.model_type`` is not 2 (no model could be built).
    """
    global best_prec1, sample_size
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
        # Name the GPU this worker actually uses (was hard-coded to device 0).
        print("Current Device is ", torch.cuda.get_device_name(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Prefix for model/checkpoint names. Defined up front because every path
    # below uses it (model_name, hyper_params); it used to be set only in the
    # ``model_type == 2`` branch, so the pre-trained path raised NameError.
    save_folder = "{}/Model/{}{}".format(args.ROOT, args.model,
                                         args.model_depth)

    # create model2:
    if args.pretrained:
        print("=> Model (date_diff): using pre-trained model '{}_{}'".format(
            args.model, args.model_depth))
        pretrained_model = models.__dict__[args.arch](pretrained=True)
    elif args.model_type == 2:
        print("=> Model (date_diff regression): creating model '{}_{}'".
              format(args.model, args.model_depth))
        pretrained_model = generate_model(args)  # good for resnet
    else:
        # Fail with a clear message instead of a NameError on pretrained_model.
        raise ValueError(
            "Unsupported model_type {!r}: expected 2 when no pre-trained "
            "model is requested".format(args.model_type))

    model = longi_models.ResNet_interval(pretrained_model,
                                         args.num_date_diff_classes,
                                         args.num_reg_labels)

    criterion0 = torch.nn.CrossEntropyLoss().cuda(args.gpu)  # for STO loss
    criterion1 = torch.nn.CrossEntropyLoss().cuda(args.gpu)  # for RISI loss

    criterion = [criterion0, criterion1]
    start_epoch = 0

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=0,
                                 amsgrad=False)

    # all models optionally resume from a checkpoint
    if args.resume_all:
        if os.path.isfile(args.resume_all):
            print("=> Model_all: loading checkpoint '{}'".format(
                args.resume_all))
            checkpoint = torch.load(args.resume_all,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Put optimizer state on this worker's GPU; a bare ``.cuda()``
            # always used the current default device (cuda:0 at this point).
            # ``cuda(None)`` keeps the old behaviour when args.gpu is None.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(args.gpu)
            start_epoch = checkpoint['epoch']
            print("=> Model_all: loaded checkpoint '{}' (epoch {})".format(
                args.resume_all, checkpoint['epoch']))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    print("batch-size = ", args.batch_size)
    print("epochs = ", args.epochs)
    print("range-weight (weight of range loss) = ", args.range_weight)
    cudnn.benchmark = True
    print(model)

    # Data loading code
    traingroup = ["train"]
    evalgroup = ["eval"]
    testgroup = ["test"]

    train_augment = ['normalize', 'flip', 'crop']  # 'rotate',
    test_augment = ['normalize', 'crop']
    eval_augment = ['normalize', 'crop']

    # Stage lists arrive as strings like "[a, b]".
    train_stages = args.train_stages.strip('[]').split(', ')
    test_stages = args.test_stages.strip('[]').split(', ')
    eval_stages = args.eval_stages.strip('[]').split(', ')
    #############################################################################
    # test-retest analysis

    trt_stages = args.trt_stages.strip('[]').split(', ')

    # DDP/DataParallel wrappers hide the sub-network behind ``.module``.
    model_pair = longi_models.ResNet_pair(getattr(model, "module", model).modelA,
                                          args.num_date_diff_classes)
    torch.cuda.set_device(args.gpu)
    model_pair = model_pair.cuda(args.gpu)

    # Name / log directory / writer are computed once and reused for the
    # whole run (they were previously rebuilt, leaking the first writer).
    if args.resume_all:
        # Strip the 8-character checkpoint suffix from the resume path.
        model_name = args.resume_all[:-8]

    else:
        model_name = save_folder + "_" + time.strftime("%Y-%m-%d_%H-%M")+ \
                     traingroup[0] + '_' + args.train_stages.strip('[]').replace(', ', '')

    data_name = args.datapath.split("/")[-1]

    log_name = (args.ROOT + "/log/" + args.model + str(args.model_depth) +
                "/" + data_name + "/" + time.strftime("%Y-%m-%d_%H-%M"))
    writer = SummaryWriter(log_name)

    trt_dataset = long.LongitudinalDataset3DPair(
        args.datapath, testgroup, args.datapath + "/test_retest_list.csv",
        trt_stages, test_augment, args.max_angle, args.rotate_prob,
        sample_size)

    trt_loader = torch.utils.data.DataLoader(trt_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("\nEvaluation on Test-Retest Set: ")

    util.validate_pair(trt_loader, model_pair, criterion,
                       model_name + "_test_retest", args.epochs, writer,
                       args.print_freq)

    ##########################################################################

    train_dataset = long.LongitudinalDataset3D(
        args.datapath,
        traingroup,
        args.datapath + "/train_list.csv",
        train_stages,
        train_augment,  # advanced transformation: add random rotation
        args.max_angle,
        args.rotate_prob,
        sample_size)

    eval_dataset = long.LongitudinalDataset3D(args.datapath, evalgroup,
                                              args.datapath + "/eval_list.csv",
                                              eval_stages, eval_augment,
                                              args.max_angle, args.rotate_prob,
                                              sample_size)

    test_dataset = long.LongitudinalDataset3D(args.datapath, testgroup,
                                              args.datapath + "/test_list.csv",
                                              test_stages, test_augment,
                                              args.max_angle, args.rotate_prob,
                                              sample_size)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # NOTE(review): train_sampler is not passed to the loader, so in
    # distributed mode every process iterates the full, unshuffled training
    # set and set_epoch() below has no effect -- confirm this is intended.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True)

    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True)

    # Use a tool at comet.com to keep track of parameters used
    # log model name, loss, and optimizer as well
    hyper_params["loss"] = criterion
    hyper_params["optimizer"] = optimizer
    hyper_params["model_name"] = model_name
    hyper_params["save_folder"] = save_folder
    experiment.log_parameters(hyper_params)
    # End of using comet

    if args.evaluate:
        print("\nEVALUATE before starting training: ")
        util.validate(eval_loader,
                      model,
                      criterion,
                      model_name + "_eval",
                      writer=writer,
                      range_weight=args.range_weight)

    # Build the early stopper once so whatever it accumulates between
    # evaluations (presumably a patience counter -- see util.EarlyStopping)
    # survives; it used to be re-created on every evaluation.
    early_stopping = None
    if args.early_stop:
        early_stopping = util.EarlyStopping(
            patience=args.patience, tolerance=args.tolerance)

    # training the model
    # (the old ``start_epoch < args.epochs - 1`` guard skipped the single
    # remaining epoch when resuming at start_epoch == args.epochs - 1)
    if start_epoch < args.epochs:
        print("\nTRAIN: ")
        for epoch in range(start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            util.adjust_learning_rate(optimizer, epoch, args.lr)

            # train for one epoch
            util.train(train_loader,
                       model,
                       criterion,
                       optimizer,
                       epoch,
                       sample_size,
                       args.print_freq,
                       writer,
                       range_weight=args.range_weight)

            # evaluate on validation set
            if epoch % args.eval_freq == 0:
                # Start a fresh per-evaluation CSV.
                csv_name = model_name + "_eval.csv"
                if os.path.isfile(csv_name):
                    os.remove(csv_name)
                prec = util.validate(eval_loader,
                                     model,
                                     criterion,
                                     model_name + "_eval",
                                     epoch,
                                     writer,
                                     range_weight=args.range_weight)

                if args.early_stop:

                    early_stopping(
                        {
                            'epoch': epoch + 1,
                            'arch1': args.arch1,
                            'arch2': args.model2 + str(args.model2_depth),
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        }, prec, model_name)

                    print("=" * 50)

                    if early_stopping.early_stop:
                        print("Early stopping at epoch", epoch, ".")
                        break

                else:
                    # remember best prec@1 and save checkpoint
                    is_best = prec > best_prec1
                    best_prec1 = max(prec, best_prec1)
                    util.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'arch': args.model + str(args.model_depth),
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                            'optimizer': optimizer.state_dict(),
                        }, is_best, model_name)

    if args.test:
        print("\nTEST: ")
        util.validate(test_loader,
                      model,
                      criterion,
                      model_name + "_test",
                      args.epochs,
                      writer,
                      range_weight=args.range_weight)

        print("\nEvaluation on Train Set: ")
        util.validate(train_loader,
                      model,
                      criterion,
                      model_name + "_train",
                      args.epochs,
                      writer,
                      range_weight=args.range_weight)

    #############################################################################################################

    # test on only the basic sub-network (STO loss)
    model_pair = longi_models.ResNet_pair(getattr(model, "module", model).modelA,
                                          args.num_date_diff_classes)
    torch.cuda.set_device(args.gpu)
    model_pair = model_pair.cuda(args.gpu)

    if args.test_pair:

        train_pair_dataset = long.LongitudinalDataset3DPair(
            args.datapath, traingroup, args.datapath + "/train_pair_list.csv",
            train_stages, test_augment, args.max_angle, args.rotate_prob,
            sample_size)

        train_pair_loader = torch.utils.data.DataLoader(
            train_pair_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)

        print("\nEvaluation on Train Pair Set: ")

        util.validate_pair(train_pair_loader, model_pair, criterion,
                           model_name + "_train_pair_update", args.epochs,
                           writer, args.print_freq)

        test_pair_dataset = long.LongitudinalDataset3DPair(
            args.datapath, testgroup, args.datapath + "/test_pair_list.csv",
            test_stages, test_augment, args.max_angle, args.rotate_prob,
            sample_size)

        test_pair_loader = torch.utils.data.DataLoader(
            test_pair_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)

        print("\nEvaluation on Test Pair Set: ")

        util.validate_pair(test_pair_loader, model_pair, criterion,
                           model_name + "_test_pair_update", args.epochs,
                           writer, args.print_freq)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
# --- Пример #6 (Example #6) ---
# 0
def main():
    """Build or resume an image classifier, then train, test and checkpoint it.

    Options are parsed with the module-level ``parser`` and stored in the
    global ``args`` so other module-level code can read them.

    With ``--resume DIR`` the solution manager and model name are restored
    via ``utility.load_checkpoint``. Otherwise a ``Composite_Classifier`` is
    built on top of a pretrained torchvision model (``--arch``, default
    ``resnet18``). The backbone's weights are frozen, so only the classifier
    layers (``cf_layers``) are optimised, using SGD with momentum and a
    StepLR schedule.

    Exits via ``sys.exit`` when:
        * ``--gpu`` is requested but CUDA is not available;
        * the data directory does not exist;
        * the ``--resume`` path is not an existing directory.
    """
    global args
    args = parser.parse_args()

    # Default settings
    default_model = 'resnet18'
    sz = 224  # side length of the square crop fed to the network

    # Honour --gpu only when CUDA is actually available.
    if args.gpu:
        HAS_CUDA = torch.cuda.is_available()
        if not HAS_CUDA:
            sys.exit('No Cuda capable GPU detected')
    else:
        HAS_CUDA = False

    # --dropout is a comma-separated list, optionally wrapped in brackets.
    # The regex must be a raw string: "\[" in a normal literal is an invalid
    # escape sequence.
    # NOTE(review): ``drops`` is only used when building a new model; it is
    # not passed to utility.load_checkpoint on the resume path.
    drops = [
        float(item)
        for item in re.sub(r"[\[\]]", "", args.dropout).split(',')
    ]

    lr = args.learning_rate
    epochs = args.epochs

    if args.resume:
        # Without this check, a non-directory resume path fell through and
        # crashed later with a NameError on ``sol_mgr``.
        if not os.path.isdir(args.resume):
            sys.exit('Checkpoint directory does not exist')
        print('Loading checkpoint...')
        sol_mgr, pt_model = utility.load_checkpoint(args.resume, lr, HAS_CUDA)

    else:
        # Hidden layer sizes: comma-separated ints, optionally bracketed.
        n_hid = [
            int(item)
            for item in re.sub(r"[\[\]]", "", args.hidden_units).split(',')
        ]

        data_dir = args.data_directory
        if not os.path.exists(data_dir):
            sys.exit('Data directory does not exist')

        # Dataset splits to load (sub-directories of data_dir).
        phrases = ['train', 'valid', 'test']

        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        # Validation and test use the same deterministic preprocessing.
        eval_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(sz),
            transforms.ToTensor(),
            normalize,
        ])
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.RandomResizedCrop(sz),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            'valid': eval_transform,
            'test': eval_transform,
        }

        bs = args.batch_size
        data = Data_Manager(data_dir, phrases, data_transforms, bs)

        # Number of output classes comes from the category-name mapping.
        cat_to_name = utility.load_classes('cat_to_name.json')
        num_cat = len(cat_to_name)

        # Load the pretrained backbone by its torchvision name.
        pt_model = args.arch if args.arch is not None else default_model
        model_pt = models.__dict__[pt_model](pretrained=True)

        img_cl = Composite_Classifier(model_pt, n_hid, drops, num_cat)
        if HAS_CUDA:
            img_cl.cuda()

        criterion = nn.CrossEntropyLoss()
        # Optimise just the parameters of the classifier layers.
        optimizer_ft = optim.SGD(img_cl.cf_layers.parameters(),
                                 lr=lr,
                                 momentum=0.9)
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,
                                               step_size=7,
                                               gamma=0.1)

        # Freeze the pretrained backbone so only the classifier learns.
        for param in img_cl.model_pt.parameters():
            param.requires_grad = False

        # The manager drives training, validation, test and prediction.
        sol_mgr = Solution_Manager(img_cl,
                                   criterion,
                                   optimizer_ft,
                                   exp_lr_scheduler,
                                   data,
                                   phrases,
                                   HAS_CUDA=HAS_CUDA)
        sol_mgr.model.class_to_idx = data.image_datasets['train'].class_to_idx

    sol_mgr.train(epochs=epochs)

    # Evaluate against the held-out test split.
    sol_mgr.test_with_dl()

    utility.save_checkpoint(args.save_dir, sol_mgr, pt_model, HAS_CUDA)
Пример #7
0
                running_loss += loss.item()

                if steps % print_every == 0:
                    model.eval()

                    with torch.no_grad():
                        test_loss, accuracy = validation(
                            model, dataloader['test'], criterion, args.gpu)

                    print(
                        "Epoch: {}/{}.. ".format(e + 1, args.epochs),
                        "Training Loss: {:.3f}.. ".format(running_loss /
                                                          print_every),
                        "Validation Loss: {:.3f}.. ".format(
                            test_loss / len(dataloader['test'])),
                        "Test Accuracy: {:.3f}".format(
                            accuracy / len(dataloader['test'])))

                    running_loss = 0
                    model.train()

    model.optimizer_state_dict = optimizer.state_dict()
    model.class_to_idx = image_dataset['train'].class_to_idx
    return model, args.save_dir


if __name__ == "__main__":

    # Script entry point: unpack the (model, save_dir) pair returned by
    # train() and write a checkpoint of the model to that location.
    model, save_dir = train()
    utility.save_checkpoint(model, save_dir)
Пример #8
0
def train_save_model():
    """Fine-tune a pretrained torchvision model on ``flowers/`` and save it.

    Options come from ``parse_args()``: ``arch``, ``learning_rate``,
    ``epochs``, ``gpu`` and ``save_dir``. The backbone is frozen, its
    ``classifier`` is replaced with a new head ending in 102 log-probability
    outputs (``LogSoftmax``), and only that head is trained with Adam and
    NLL loss. The checkpoint is written via ``save_checkpoint``.

    Raises:
        ValueError: if ``args.arch`` is not ``densenet121`` or ``vgg13``.
            Previously this surfaced as an UnboundLocalError on
            ``classifier``, after data loading and a model download.
    """
    args = parse_args()

    # Fail fast, before touching data or downloading weights, on an
    # architecture we have no classifier head for.
    supported_archs = ("densenet121", "vgg13")
    if args.arch not in supported_archs:
        raise ValueError("Unsupported architecture {!r}; expected one of: {}".format(
            args.arch, ", ".join(supported_archs)))

    data_dir = 'flowers'
    train_dir = data_dir + '/train'
    valid_dir = data_dir + '/valid'

    # Every split uses the same deterministic preprocessing (no training
    # augmentation), so a single pipeline is shared.
    image_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_datasets = datasets.ImageFolder(train_dir,
                                          transform=image_transforms)
    validate_datasets = datasets.ImageFolder(valid_dir,
                                             transform=image_transforms)

    traindataloaders = torch.utils.data.DataLoader(train_datasets,
                                                   batch_size=64,
                                                   shuffle=True)
    # Order does not matter for aggregate validation metrics.
    validatedataloaders = torch.utils.data.DataLoader(validate_datasets,
                                                      batch_size=64,
                                                      shuffle=False)

    model = getattr(models, args.arch)(pretrained=True)

    # Freeze parameters so we don't backprop through them.
    for param in model.parameters():
        param.requires_grad = False

    if args.arch == "densenet121":
        # Read the feature width from the model instead of hard-coding 1024.
        feature_num = model.classifier.in_features
        classifier = nn.Sequential(nn.Linear(feature_num, 512), nn.ReLU(),
                                   nn.Dropout(0.2), nn.Linear(512, 256),
                                   nn.ReLU(), nn.Dropout(0.2),
                                   nn.Linear(256, 102), nn.LogSoftmax(dim=1))
    else:  # "vgg13": its classifier is a Sequential starting with a Linear
        feature_num = model.classifier[0].in_features
        classifier = nn.Sequential(nn.Linear(feature_num, 1024), nn.ReLU(),
                                   nn.Dropout(0.2), nn.Linear(1024, 102),
                                   nn.LogSoftmax(dim=1))

    # NLLLoss pairs with the LogSoftmax output of the new head.
    criterion = nn.NLLLoss()
    model.classifier = classifier
    optimizer = optim.Adam(model.classifier.parameters(),
                           lr=float(args.learning_rate))

    epochs = int(args.epochs)
    gpu = args.gpu

    train_model(model, epochs, gpu, criterion, optimizer, traindataloaders,
                validatedataloaders)

    # Store the label mapping so the checkpoint can translate predictions.
    model.class_to_idx = train_datasets.class_to_idx

    save_checkpoint(args.save_dir, model, optimizer, args, classifier)
                default=0.001)
# type=float so a CLI-supplied dropout is numeric like the default, not a str.
ap.add_argument('--dropout',
                dest="dropout",
                action="store",
                type=float,
                default=0.5)
ap.add_argument('--epochs', dest="epochs", action="store", type=int, default=1)
ap.add_argument('--arch',
                dest="arch",
                action="store",
                default="vgg16",
                type=str)
ap.add_argument('--hidden_units',
                type=int,
                dest="hidden_units",
                action="store",
                default=4096)

# Parse the command line and unpack the options used below.
pa = ap.parse_args()
directory = pa.data_dir
checkpoint = pa.save_dir
lrate = pa.learning_rate
architecture = pa.arch
dropout = pa.dropout
hidden_units = pa.hidden_units
power = pa.gpu  # presumably the GPU flag/device selector for training
epochs = pa.epochs

model, optimizer, criterion = utility.design_model(architecture, dropout,
                                                   lrate, hidden_units)

# Train for the requested number of epochs. This was hard-coded to 1, which
# silently ignored --epochs while still recording ``epochs`` in the
# checkpoint.
# NOTE(review): assumes the 4th positional parameter of utility.train_model
# is the epoch count -- confirm against utility.
utility.train_model(model, criterion, optimizer, epochs, power)

utility.save_checkpoint(model, optimizer, epochs)
Пример #10
0
def _prepare_batch(img, gt, device):
    """Split a loader batch into the model input and targets on ``device``.

    Returns ``(img, gt_alpha, gt_trimap)``:
        * ``img`` is cast to float;
        * ``gt_alpha`` is channel 0 of ``gt`` as float, with the channel
          axis kept;
        * ``gt_trimap`` is channel 1 of ``gt`` as long.

    Shape comments follow the original author's annotations.
    """
    img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
    gt_alpha = gt[:, 0, :, :].unsqueeze(1).type(
        torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
    gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(
        device)  # [bs, 320, 320]
    return img, gt_alpha, gt_trimap


def _validate_epoch(model, valid_loader, device, writer, cur_iter):
    """Evaluate ``model`` on ``valid_loader`` in eval mode without gradients.

    The first batch's adapted trimap and alpha estimate are logged as images
    to ``writer`` at step ``cur_iter``.

    Returns:
        ``(overall, trimap, alpha)``: the ``AverageMeter.avg`` values of the
        per-batch mean losses.
    """
    avg_loss = AverageMeter()
    avg_l_t = AverageMeter()
    avg_l_a = AverageMeter()
    model.eval()
    # Scoped no_grad restores the caller's grad mode on exit, unlike a global
    # torch.set_grad_enabled(False).
    with torch.no_grad(), tqdm(total=len(valid_loader)) as pbar:
        for index, (img, gt) in enumerate(valid_loader):
            img, gt_alpha, gt_trimap = _prepare_batch(img, gt, device)

            (trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr,
             log_sigma_a_sqr) = model(img)
            L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(
                pred_trimap=trimap_adaption,
                pred_trimap_argmax=t_argmax,
                pred_alpha=alpha_estimation,
                gt_trimap=gt_trimap,
                gt_alpha=gt_alpha,
                log_sigma_t_sqr=log_sigma_t_sqr,
                log_sigma_a_sqr=log_sigma_a_sqr)

            avg_loss.update(L_overall_valid.mean().item())
            avg_l_t.update(L_t_valid.mean().item())
            avg_l_a.update(L_a_valid.mean().item())

            if index == 0:
                # Argmax labels are halved for display (presumably trimap
                # labels 0..2 -> [0, 1]; confirm against the dataset).
                trimap_adaption_res = (t_argmax.type(torch.FloatTensor) /
                                       2).unsqueeze(dim=1)
                trimap_adaption_res = torchvision.utils.make_grid(
                    trimap_adaption_res, normalize=False, scale_each=True)
                writer.add_image('valid_image/trimap_adaptation',
                                 trimap_adaption_res, cur_iter)
                alpha_estimation_res = torchvision.utils.make_grid(
                    alpha_estimation, normalize=True, scale_each=True)
                writer.add_image('valid_image/alpha_estimation',
                                 alpha_estimation_res, cur_iter)

            pbar.update()

    return avg_loss.avg, avg_l_t.avg, avg_l_a.avg


def train(args, logger, device_ids):
    """Train AdaMatting, validating and (conditionally) checkpointing each epoch.

    Args:
        args: parsed options. Reads ``lr``, ``resume``, ``cuda``,
            ``raw_data_path``, ``batch_size``, ``epochs``, ``decay_iters``,
            ``save_ckpt`` and ``ckpt_path``.
        logger: logger for progress messages.
        device_ids: CUDA device indices. The first hosts the model; more
            than one wraps it in ``DataParallel``. Only used if ``args.cuda``.

    Scalars and images go to a TensorBoard ``SummaryWriter``. A checkpoint is
    saved when the validation loss improves, or every epoch when
    ``args.save_ckpt`` is set.
    """
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0)

    # Resolve the target device first so resumed optimizer state can be moved
    # to the exact device that will hold the parameters.
    if args.cuda:
        device = torch.device("cuda:{}".format(device_ids[0]))
    else:
        device = torch.device("cpu")

    if args.resume != "":
        # map_location="cpu" allows resuming a GPU-saved checkpoint on any
        # machine; tensors are moved to ``device`` below.
        ckpt = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])

    if args.cuda:
        # Optimizer state must sit on the parameters' device (device_ids[0]).
        # A bare .cuda() would use the current default GPU instead.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')

    # Training-loss averages, reset after every log (every 10 iterations).
    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training. enable_grad is scoped so gradients are on even if the
        # caller disabled them globally, and the caller's mode is restored.
        model.train()
        with torch.enable_grad():
            for index, (img, gt) in enumerate(train_loader):
                cur_lr, peak_lr = lr_scheduler(optimizer=optimizer,
                                               cur_iter=cur_iter,
                                               peak_lr=peak_lr,
                                               end_lr=0.00001,
                                               decay_iters=args.decay_iters,
                                               decay_power=0.9,
                                               power=0.9)

                img, gt_alpha, gt_trimap = _prepare_batch(img, gt, device)

                (trimap_adaption, t_argmax, alpha_estimation,
                 log_sigma_t_sqr, log_sigma_a_sqr) = model(img)
                L_overall, L_t, L_a = task_uncertainty_loss(
                    pred_trimap=trimap_adaption,
                    pred_trimap_argmax=t_argmax,
                    pred_alpha=alpha_estimation,
                    gt_trimap=gt_trimap,
                    gt_alpha=gt_alpha,
                    log_sigma_t_sqr=log_sigma_t_sqr,
                    log_sigma_a_sqr=log_sigma_a_sqr)

                L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
                sigma_t = log_sigma_t_sqr.mean()
                sigma_a = log_sigma_a_sqr.mean()

                optimizer.zero_grad()
                L_overall.backward()
                optimizer.step()

                avg_lo.update(L_overall.item())
                avg_lt.update(L_t.item())
                avg_la.update(L_a.item())

                if cur_iter % 10 == 0:
                    logger.info(
                        "Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                        .format(epoch, index, len(train_loader), avg_lo.avg,
                                avg_lt.avg, avg_la.avg))
                    writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                    writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                    writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                    # exp(log(sigma^2) / 2) == sigma
                    writer.add_scalar("other/sigma_t",
                                      torch.exp(sigma_t / 2).item(), cur_iter)
                    writer.add_scalar("other/sigma_a",
                                      torch.exp(sigma_a / 2).item(), cur_iter)
                    writer.add_scalar("other/lr", cur_lr, cur_iter)

                    avg_lo.reset()
                    avg_lt.reset()
                    avg_la.reset()

                cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        torch.cuda.empty_cache()
        valid_lo, valid_lt, valid_la = _validate_epoch(model, valid_loader,
                                                       device, writer,
                                                       cur_iter)

        logger.info("Average loss overall: {:.4e}".format(valid_lo))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(
            valid_lt))
        logger.info("Average loss of alpha estimation: {:.4e}".format(
            valid_la))
        writer.add_scalar("valid_loss/L_overall", valid_lo, cur_iter)
        writer.add_scalar("valid_loss/L_t", valid_lt, cur_iter)
        writer.add_scalar("valid_loss/L_a", valid_la, cur_iter)

        is_best = valid_lo < best_loss
        best_loss = min(valid_lo, best_loss)
        if is_best or args.save_ckpt:
            # exist_ok avoids the check-then-create race.
            os.makedirs("ckpts", exist_ok=True)
            save_checkpoint(ckpt_path=args.ckpt_path,
                            is_best=is_best,
                            logger=logger,
                            model=model,
                            optimizer=optimizer,
                            epoch=epoch,
                            cur_iter=cur_iter,
                            peak_lr=peak_lr,
                            best_loss=best_loss)

    writer.close()
Пример #11
0
model = utility.torch_model(args.arch)

# Freeze every backbone weight; only the replacement classifier is trained.
for param in model.parameters():
    param.requires_grad_(False)

# Width of the feature vector the backbone hands to its classifier.
input_size = utility.get_input_size(model, args.arch)

model.classifier = network.Network(
    input_size,
    args.output_size,
    [args.hidden_units],
    drop_p=0.35,
)

# The classifier ends in log-softmax, so negative log-likelihood is the
# matching loss.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.classifier.parameters(), lr=args.learning_rate)
# Learning rate is multiplied by 0.1 every 4 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

# Fit the classifier on the training data, tracking the validation set.
network.train_network(
    model,
    trainloader,
    validloader,
    args.epochs,
    32,
    criterion,
    optimizer,
    scheduler,
    args.gpu,
)

# Final evaluation on the held-out test set.
test_accuracy, test_loss = network.check_accuracy_loss(
    model, testloader, criterion, args.gpu)
print(f"\n ---\n Test Accuracy: {test_accuracy * 100:.2f} %",
      f"Test Loss: {test_loss}")

# Persist the trained network.
utility.save_checkpoint(model, train_data, optimizer, args.save_dir, args.arch)
Пример #12
0
from get_input_args import get_input_args
import numpy as np
from utility import set_std_dict, set_model, save_checkpoint
from training import train_model, test_model
import json

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

# Gather command-line options, build the data dictionary and the model.
in_arg = get_input_args()
std_dict = set_std_dict(in_arg.data_dir)
model = set_model(in_arg.arch, in_arg.gpu, in_arg.hidden_units, std_dict)

epochs = in_arg.epochs
learning_rate = in_arg.learning_rate

criterion = nn.NLLLoss()
# Only the trainable head is optimised: vgg16 exposes it as ``classifier``;
# any other architecture is expected to expose it as ``fc``.
optimizer = optim.Adam(
    (model.classifier if in_arg.arch == "vgg16" else model.fc).parameters(),
    lr=learning_rate,
)

# Train, evaluate, then write the checkpoint.
train_model(model, criterion, optimizer, epochs, in_arg.gpu, std_dict)
test_model(model, in_arg.gpu, std_dict)

save_checkpoint(model, in_arg.arch, optimizer, epochs, in_arg.save_dir,
                std_dict)