Exemple #1
0
class EarlyStopping:
    """Track validation loss and flag when training should stop.

    Every call negates the validation loss into a score and compares it
    with the best score seen so far. A score that beats the best one by
    more than ``delta`` resets the patience counter and triggers
    ``self.checkpoint.save``; any other score (including NaN) counts as
    "no improvement". After ``patience`` such calls ``early_stop`` is set.
    """

    def __init__(self, patience=5, delta=0, verbose=True):
        """Store the thresholds and create the checkpoint helper.

        Args:
            patience: number of non-improving calls tolerated before
                ``early_stop`` becomes True.
            delta: margin by which the score must exceed the best score
                to count as an improvement.
            verbose: stored on the instance; not read by this class.
        """
        self.patience = patience
        self.verbose = verbose
        self.best_score = None
        self.delta = delta
        self.counter = 0
        self.early_stop = False
        self.checkpoint = Checkpoint("model")

    def __call__(self, model, val_loss, optimizer, epoch=None):
        """Record one validation result and update the stopping state.

        Args:
            model: forwarded unchanged to ``self.checkpoint.save``.
            val_loss: validation loss for this round (lower is better).
            optimizer: forwarded unchanged to ``self.checkpoint.save``.
            epoch: optional epoch index forwarded to the checkpoint.
        """
        score = -val_loss
        # Identity test: `== None` misbehaves for array/tensor-like scores.
        if self.best_score is None or score > self.best_score + self.delta:
            self.counter = 0
            self.best_score = score
            self.checkpoint.save(model, optimizer, val_loss, epoch)
        else:
            # Plain `else` so that a NaN loss (which fails every comparison)
            # is counted as "no improvement" instead of being ignored.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def reset_counter(self):
        """Reset the patience counter (``early_stop`` is left untouched)."""
        self.counter = 0

    def get_early_stop(self):
        """Return True once the early-stopping criterion has been met."""
        return self.early_stop

    def get_checkpoint(self):
        """Return the checkpoint helper used for saving and loading."""
        return self.checkpoint
Exemple #2
0
 def __init__(self, patience=5, delta=0, verbose=True):
     """Store the constructor arguments and initialise the tracking state."""
     # Caller-supplied settings, kept verbatim on the instance.
     self.patience, self.delta, self.verbose = patience, delta, verbose
     # Bookkeeping that starts empty: no best score yet, nothing counted,
     # and the stop flag lowered.
     self.best_score = None
     self.counter = 0
     self.early_stop = False
     # Checkpoint helper created with the fixed name "model".
     self.checkpoint = Checkpoint("model")
def main():
    """Load a pose-model checkpoint and collect train/val distributions.

    Parses the training options, makes sure the output directory under
    ``opt.exp_dir/opt.exp_id`` exists, builds the model through
    ``model.create_hg``, loads its checkpoint and finally calls
    ``collect_train_valid_data`` once for the training split and once for
    the validation split.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'mpii'`` (the only dataset
            for which a class count is defined here).
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id, opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    # Constructed but not used further in this function; kept in case the
    # constructor has side effects.
    train_history = ASNTrainHistory()
    checkpoint_hg = Checkpoint()

    train_distri_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    train_distri_path_2 = sr_pretrain_dir + '/' + 'train_rotations_copy.txt'
    val_distri_path = sr_pretrain_dir + '/' + 'val_rotations.txt'
    val_distri_path_2 = sr_pretrain_dir + '/' + 'val_rotations_copy.txt'

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously this fell through and crashed later with a NameError
        # on `num_classes`; fail early with an explicit message instead.
        raise ValueError('unsupported dataset: {!r}'.format(opt.dataset))
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        # Only a warning: execution deliberately continues afterwards.
        print('please input the checkpoint name of the pose model')
    # The trailing character of the joined path is dropped before use.
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # print() with a single argument behaves the same on Python 2 and 3.
    print('collecting training distributions ...\n')
    train_distri_list = collect_train_valid_data(train_distri_path,
                                                 train_distri_path_2,
                                                 hg,
                                                 opt,
                                                 is_train=True)

    print('collecting validation distributions ...\n')
    val_distri_list = collect_train_valid_data(val_distri_path,
                                               val_distri_path_2,
                                               hg,
                                               opt,
                                               is_train=False)
Exemple #4
0
def main():
    """Run the train/validate loop for the network created from the options.

    Builds the network and its Adam optimizer via the ``model`` module,
    optionally restores them from a checkpoint, creates the training and
    validation data loaders and then, for every epoch from
    ``opt.resume_epoch`` up to ``opt.nEpochs``, trains, validates, records
    the results in the training history, saves a checkpoint and plots the
    history. When the history reports a new best epoch, validation is run
    once more with ``is_show=True``.
    """
    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint(opt)
    visualizer = Visualizer(opt)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
    cudnn.benchmark = True

    # Build the graph, create the optimizer, then move the net to the GPU.
    net = model.CreateNet(opt)
    optimizer = model.CreateAdamOptimizer(opt, net)
    net.cuda()

    # Optionally resume from a checkpoint.
    checkpoint.load_checkpoint(net, optimizer, train_history)

    def _make_loader(list_file, is_train):
        # Training data is shuffled, validation data is not.
        dataset = ImageLoader(os.path.join(opt.data_dir, list_file),
                              transforms.ToTensor(),
                              is_train=is_train)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=opt.bs,
                                           shuffle=is_train,
                                           num_workers=opt.nThreads,
                                           pin_memory=True)

    # Load data.
    train_loader = _make_loader(opt.train_list, True)
    val_loader = _make_loader(opt.val_list, False)

    # Training and validation.
    for epoch in range(opt.resume_epoch, opt.nEpochs):
        model.AdjustLR(opt, optimizer, epoch)

        # One training pass; the third returned value is not used here.
        train_loss_det, train_loss_reg, _ = train(
            train_loader, net, optimizer, epoch, visualizer)

        # One validation pass; the combined loss is not used here.
        val_loss_det, val_loss_reg, _, det_rmse, reg_rmse = validate(
            val_loader, net, epoch, visualizer, is_show=False)

        # Update the training history, checkpoint and plots.
        epoch_entry = OrderedDict([('epoch', epoch)])
        lr_entry = OrderedDict([('lr', opt.lr)])
        loss_entry = OrderedDict([('train_loss_det', train_loss_det),
                                  ('train_loss_reg', train_loss_reg),
                                  ('val_loss_det', val_loss_det),
                                  ('val_loss_reg', val_loss_reg)])
        rmse_entry = OrderedDict([('det_rmse', det_rmse),
                                  ('val_rmse', reg_rmse)])
        train_history.update(epoch_entry, lr_entry, loss_entry, rmse_entry)
        checkpoint.save_checkpoint(net, optimizer, train_history)
        visualizer.plot_train_history(train_history)

        if not train_history.is_best:
            continue
        # Plot the best validation result.
        visualizer.imgpts_win_id = 4
        validate(val_loader, net, epoch, visualizer, is_show=True)
Exemple #5
0
def main():
    """Dispatch to training (fresh or resumed) or to inference.

    Seeds the random generators from ``par.seed``. When both ``par.trn``
    and ``par.val`` are set, a model is either restored from the checkpoint
    directory or built from scratch, and ``Training`` is started. Otherwise,
    when ``par.tst`` is set, a checkpoint is loaded and ``Inference`` is
    run. If neither condition holds the function does nothing.
    """
    par = Params(sys.argv)
    random.seed(par.seed)
    torch.manual_seed(par.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(par.seed)

    if par.trn and par.val:
        chk = Checkpoint(par.dir)

        if not chk.contains_model:
            # Training from scratch: read config, build model and optimizer.
            cfg = Config(par)
            mod = Model(cfg)
            if cfg.cuda:
                mod.cuda()
            opt = Optimizer(cfg, mod)
            print_time('Learning [from scratch]...')
        else:
            # Resume training from the stored checkpoint.
            cfg, mod, opt = chk.load(par)
            print_time('Learning [resume It={}]...'.format(cfg.n_iters_sofar))

        # Positional arguments shared by the training and validation sets.
        shared = (cfg.svoc, cfg.tvoc, par.batch_size,
                  par.max_src_len, par.max_tgt_len)
        trn = Dataset(par.trn, *shared,
                      do_shuffle=True, do_filter=True, is_test=False)
        val = Dataset(par.val, *shared,
                      do_shuffle=True, do_filter=True, is_test=True)
        Training(cfg, mod, opt, trn, val, chk)
        return

    if par.tst:
        # Inference: load the checkpoint named by par.chk.
        chk = Checkpoint()
        cfg, mod, opt = chk.load(par, par.chk)
        tst = Dataset(par.tst, cfg.svoc, cfg.tvoc, par.batch_size, 0, 0,
                      do_shuffle=False, do_filter=False, is_test=True)
        print_time('Inference [model It={}]...'.format(cfg.n_iters_sofar))
        Inference(cfg, mod, tst)
def main():
    """Train, evaluate or demo the linear model on the 4-view point lists.

    Builds the model with ``MakeLinearModel(1024, 16)``, loads the
    single-view checkpoint, reads the mean / standard-deviation text files,
    creates the train / validation / demo loaders and an Adam optimizer.
    Depending on the module-level flags ``evaluate_mode`` and ``demo_mode``
    it then runs a single validation pass, a demo pass, or the full
    training loop (one checkpoint save per epoch).
    """
    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint(opt)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Architecture
    net = MakeLinearModel(1024, 16)
    net = torch.nn.DataParallel(net).cuda()
    checkpoint.load_checkpoint(net, train_history, '/best-single.pth.tar')

    # Mean and SD statistics
    path_to_data = '.../multi-view-pose-estimation/dataset/'

    def _load_stat(filename, to_gpu):
        # Read one statistics file as a float32 tensor; optionally move it
        # to the GPU wrapped in a non-trainable Variable.
        values = np.loadtxt(path_to_data + filename).astype('float32')
        tensor = torch.from_numpy(values)
        if to_gpu:
            # `async` is a reserved word since Python 3.7; PyTorch renamed
            # the keyword to `non_blocking`.
            tensor = torch.autograd.Variable(tensor.cuda(non_blocking=True),
                                             requires_grad=False)
        return tensor

    # The 2D statistics stay on the CPU; the others are moved to the GPU.
    Mean_2D = _load_stat('Mean_2D.txt', False)
    Mean_Delta = _load_stat('Mean_Delta.txt', True)
    Mean_3D = _load_stat('Mean_3D.txt', True)
    SD_2D = _load_stat('SD_2D.txt', False)
    SD_Delta = _load_stat('SD_Delta.txt', True)
    SD_3D = _load_stat('SD_3D.txt', True)

    # Loading data
    train_list = 'train_list_4view.txt'
    train_loader = torch.utils.data.DataLoader(
        data4view.PtsList(train_list, is_train=True),
        batch_size=opt.bs,
        shuffle=True,
        num_workers=opt.nThreads,
        pin_memory=True)

    val_list = 'valid_list_4view.txt'
    val_loader = torch.utils.data.DataLoader(
        data4view.PtsList(val_list, is_train=False),
        batch_size=opt.bs,
        shuffle=False,
        num_workers=opt.nThreads,
        pin_memory=True)

    demo_list = 'demo_list_4view.txt'
    demo_loader = torch.utils.data.DataLoader(
        data4view.PtsList(demo_list, is_train=False),
        batch_size=1,
        shuffle=False,
        num_workers=opt.nThreads,
        pin_memory=True)

    # Optimizer
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opt.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0)

    # Validation only
    if evaluate_mode:
        checkpoint.load_checkpoint(net, train_history, '/best-multi.pth.tar')
        val_loss, val_pckh = validate(val_loader, net, Mean_2D, Mean_Delta,
                                      Mean_3D, SD_2D, SD_Delta, SD_3D, 0, opt)
        return

    # Demo
    if demo_mode:
        checkpoint.load_checkpoint(net, train_history, '/best-multi.pth.tar')
        demo(demo_loader, net, Mean_2D, Mean_Delta, Mean_3D, SD_2D, SD_Delta,
             SD_3D, 0, opt)
        return

    # Training
    for epoch in range(0, opt.nEpochs):
        adjust_learning_rate(optimizer, epoch, opt.lr)

        # train for one epoch
        train_loss = train(train_loader, net, Mean_2D, Mean_Delta, Mean_3D,
                           SD_2D, SD_Delta, SD_3D, optimizer, epoch, opt)

        # evaluate on validation set
        val_loss, val_pckh = validate(val_loader, net, Mean_2D, Mean_Delta,
                                      Mean_3D, SD_2D, SD_Delta, SD_3D, epoch,
                                      opt)

        # update training history
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', opt.lr)])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        pckh = OrderedDict([('val_pckh', val_pckh)])
        train_history.update(e, lr, loss, pckh)
        # NOTE(review): saved without the leading '/' that the load calls
        # above use -- confirm against Checkpoint's path handling.
        checkpoint.save_checkpoint(net, train_history, 'best-multi.pth.tar')
Exemple #7
0
    def train(
        self, train_loader, val_loader=None, max_epochs=1000, enable_early_stopping=True
    ):
        """Train ``self.model`` and return the model restored from checkpoint.

        Args:
            train_loader: iterable of ``(inputs, labels)`` batches.
            val_loader: optional validation loader passed to
                ``self.validation``; without it early stopping is disabled.
            max_epochs: upper bound on the number of epochs.
            enable_early_stopping: stop once ``self.early_stopping`` reports
                that its criterion is met (requires ``val_loader``).

        Returns:
            The ``"model"`` entry of ``checkpoint.load(...)``, which is also
            stored back on ``self.model``.
        """
        if val_loader is None:
            enable_early_stopping = False

        print()
        print("-" * 2, "Training Setup", "-" * 2)
        print(f"Maximum Epochs: {max_epochs}")
        print(f"Enable Early Stoping: {enable_early_stopping}")
        print("-" * 20)
        print("*Start Training.")

        # model setup
        self.model.train().to(self.device)
        if (
            self.multi_gpus
            and torch.cuda.device_count() > 1
            and not isinstance(self.model, nn.DataParallel)
        ):
            # Guarded so that calling train() again does not wrap an already
            # wrapped model in a second DataParallel layer.
            print(f"*Using {torch.cuda.device_count()} GPUs!")
            self.model = nn.DataParallel(self.model)

        # early stopping instance
        if enable_early_stopping:
            if self.early_stopping is None:
                self.early_stopping = EarlyStopping(patience=5)
            else:
                self.early_stopping.reset_counter()

        # Defined up front: both names are read after the loop, and used to
        # be unbound when no validation loader was given (val_loss) or when
        # max_epochs < 1 (epoch).
        epoch = 0
        val_loss = None

        # training start!
        for epoch in range(1, max_epochs + 1):
            running_loss = 0.0

            for step, data in enumerate(train_loader, start=1):
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Zero the parameter gradients
                self.optimizer.zero_grad()
                # forward + backward + optimize
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)
                loss.backward()
                self.optimizer.step()
                # print statistics
                running_loss += loss.item()

                if step % 100 == 0 or step == len(train_loader):
                    print(
                        f"[{epoch}/{max_epochs}, {step}/{len(train_loader)}] loss: {running_loss / step :.3f}"
                    )

            # train & validation loss
            train_loss = running_loss / len(train_loader)
            if val_loader is None:
                print(f"train loss: {train_loss:.3f}")
            else:
                # FIXME: fixed the problem that first validation is not correct
                val_loss = self.validation(val_loader)
                print(f"train loss: {train_loss:.3f}, val loss: {val_loss:.3f}")

                if enable_early_stopping:
                    self.early_stopping(self.model, val_loss, self.optimizer)
                    if self.early_stopping.get_early_stop():
                        print("*Early Stopping.")
                        break

        print("*Finished Training!")
        if enable_early_stopping:
            checkpoint = self.early_stopping.get_checkpoint()
        else:
            checkpoint = Checkpoint()
            # val_loss is None when training ran without a validation loader.
            # NOTE(review): confirm tmp_save accepts None for the loss.
            checkpoint.tmp_save(self.model, self.optimizer, epoch, val_loss)
        self.checkpoint = checkpoint
        self.model = checkpoint.load(self.model, self.optimizer)["model"]
        return self.model
def main():
	"""Train the First_Glance network and log per-epoch scores.

	Parses the command-line arguments, creates the TensorBoard writer and
	the data loaders, builds the network, optionally restores it from a
	checkpoint, and then alternates ``train_eval`` / ``validate_eval`` for
	each epoch while writing scalars and saving a checkpoint.
	"""
	global writer
	global cp_recorder

	args = parser.parse_args()
	print('\n====> Input Arguments')
	print(args)
	# Tensorboard writer.
	writer = SummaryWriter(log_dir=args.result_path)

	# Create dataloader.
	# print() with one argument behaves the same on Python 2 and 3.
	print('\n====> Creating dataloader...')
	train_loader = get_train_loader(args)
	test_loader = get_test_loader(args)

	# Load First Glance network.
	print('====> Loading the network...')
	model = First_Glance(num_classes=args.num_classes, pretrained=True)

	# Load checkpoint and weight of network.
	if args.checkpoint_dir:
		cp_recorder = Checkpoint(args.checkpoint_dir, args.checkpoint_name)
		cp_recorder.load_checkpoint(model)

	model.cuda()
	criterion = nn.CrossEntropyLoss().cuda()
	cudnn.benchmark = True
	optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum)

	# Train first-glance model.
	# NOTE(review): cp_recorder is only assigned above when
	# args.checkpoint_dir is set; unless the module defines it elsewhere,
	# the loop below fails without a checkpoint directory -- confirm.
	print('====> Training...')
	for epoch in range(cp_recorder.contextual['b_epoch'], args.epoch):
		_, _, prec_tri, rec_tri, ap_tri = train_eval(train_loader, test_loader, model, criterion, optimizer, args, epoch)
		top1_avg_val, loss_avg_val, prec_val, rec_val, ap_val = validate_eval(test_loader, model, criterion, args, epoch)

		# Mean AP with NaN entries replaced by zero; computed once and
		# reused by every consumer below.
		map_tri = np.nan_to_num(ap_tri).mean()
		map_val = np.nan_to_num(ap_val).mean()

		# Print result.
		writer.add_scalars('mAP (per epoch)', {'train': map_tri}, epoch)
		writer.add_scalars('mAP (per epoch)', {'valid': map_val}, epoch)
		print('\n====> Scores')
		print('[Epoch {0}]:\n'
			'  Train:\n'
			'    Prec@1 {1}\n'
			'    Recall {2}\n'
			'    AP {3}\n'
			'    mAP {4:.3f}\n'
			'  Valid:\n'
			'    Prec@1 {5}\n'
			'    Recall {6}\n'
			'    AP {7}\n'
			'    mAP {8:.3f}\n'.format(epoch,
				prec_tri, rec_tri, ap_tri, map_tri,
				prec_val, rec_val, ap_val, map_val))

		# Record. The step is the number of batches seen so far.
		global_step = (epoch+1)*len(train_loader)
		writer.add_scalars('Loss (per batch)', {'valid': loss_avg_val}, global_step)
		writer.add_scalars('Prec@1 (per batch)', {'valid': top1_avg_val}, global_step)
		writer.add_scalars('mAP (per batch)', {'valid': map_val}, global_step)

		# Save checkpoint.
		cp_recorder.record_contextual({'b_epoch': epoch+1, 'b_batch': -1, 'prec': top1_avg_val, 'loss': loss_avg_val,
			'class_prec': prec_val, 'class_recall': rec_val, 'class_ap': ap_val, 'mAP': map_val})
		cp_recorder.save_checkpoint(model)
def main():
    """Train, or only validate, the network built by ``create_cu_net``.

    Parses the options, builds the network and an RMSprop optimizer,
    optionally resumes from a checkpoint, creates the MPII train and
    validation loaders and a summary logger. With ``opt.is_train`` unset it
    runs one validation pass and saves the predictions; otherwise it runs
    the epoch loop, saving a checkpoint and a log row per epoch.
    """
    global quan_op

    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    num_classes = 16
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        num_classes=num_classes,
                        layer_num=opt.layer_num,
                        max_link=1,
                        inter_loss_num=opt.layer_num)
    net = torch.nn.DataParallel(net).cuda()
    quan_op = QuanOp(net)
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)

    # Optionally resume from a checkpoint. The save prefix is the same in
    # both cases, so it is set once up front.
    checkpoint.save_prefix = exp_dir + '/'
    resume_log = opt.resume_prefix != ''
    if resume_log:
        # The trailing character of the joined path is dropped before use.
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
        opt.lr = optimizer.param_groups[0]['lr']
    # Formatted so the output matches the former
    # `print 'save prefix: ', value` on both Python 2 and Python 3.
    print('save prefix:  {}'.format(checkpoint.save_prefix))

    # Load data.
    annotation_file = 'dataset/mpii-hr-lsp-normalizer.json'
    image_root = '/bigdata1/zt53/data'
    train_loader = torch.utils.data.DataLoader(
        MPII(annotation_file, image_root, is_train=True),
        batch_size=opt.bs,
        shuffle=True,
        num_workers=opt.nThreads,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        MPII(annotation_file, image_root, is_train=False),
        batch_size=opt.bs,
        shuffle=False,
        num_workers=opt.nThreads,
        pin_memory=True)

    print(type(optimizer))
    # Class indices forwarded to train()/validate(): all of 0..15 except
    # 6, 7, 8, 9, 12 and 13.
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]
    logger = Logger(os.path.join(exp_dir, 'training-summary.txt'),
                    title='training-summary',
                    resume=resume_log)
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    if not opt.is_train:
        # Validation only: evaluate at the last recorded epoch and store
        # the predictions.
        visualizer.log_path = os.path.join(exp_dir, 'val_log.txt')
        val_loss, val_pckh, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer, idx,
            joint_flip_index, num_classes)
        checkpoint.save_preds(predictions)
        return

    # Training and validation. When resuming, continue after the last
    # epoch recorded in the history.
    start_epoch = 0
    if resume_log:
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        adjust_lr(opt, optimizer, epoch)
        # train for one epoch
        train_loss, train_pckh = train(train_loader, net, optimizer, epoch,
                                       visualizer, idx, opt)

        # evaluate on validation set
        val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                   visualizer, idx,
                                                   joint_flip_index,
                                                   num_classes)
        # update training history
        current_lr = optimizer.param_groups[0]['lr']
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', current_lr)])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        pckh = OrderedDict([('val_pckh', val_pckh)])
        train_history.update(e, lr, loss, pckh)
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        logger.append([
            epoch, current_lr, train_loss, val_loss, train_pckh, val_pckh
        ])
    logger.close()
Exemple #10
0
def main():
    """Train, or only validate, the face network built by ``create_cu_net``.

    Parses the options, builds the network, optionally copies pre-trained
    weights into it, optionally freezes everything except the ``cholesky``
    fully-connected layers, creates the optimizer and data loaders and then
    either runs one validation pass (``opt.is_train`` unset) or the epoch
    loop, saving a checkpoint and a log row per epoch.
    """
    global weights_HG

    opt = TrainOptions().parse()
    train_history = TrainHistoryFace()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    num_classes = opt.class_num

    if not opt.slurm:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    layer_num = opt.layer_num
    order = opt.order
    net = create_cu_net(neck_size=4, growth_rate=32, init_chan_num=128,
                        class_num=num_classes, layer_num=layer_num, order=order,
                        loss_num=layer_num, use_spatial_transformer=opt.stn,
                        mlp_tot_layers=opt.mlp_tot_layers,
                        mlp_hidden_units=opt.mlp_hidden_units,
                        get_mean_from_mlp=opt.get_mean_from_mlp)

    # Load the pre-trained model
    saved_wt_file = opt.saved_wt_file
    if saved_wt_file == "":
        print("=> Training from scratch")
    else:
        print("=> Loading weights from " + saved_wt_file)
        checkpoint_t = torch.load(saved_wt_file)
        state_dict = checkpoint_t['state_dict']

        # Fetch the target state dict once instead of once per lookup; its
        # tensors share storage with the network, so copy_() updates it.
        own_state = net.state_dict()
        for name, param in state_dict.items():
            # Drops the first 7 characters of every key -- presumably the
            # 'module.' prefix added by DataParallel; verify against the
            # saved checkpoints.
            name = name[7:]
            if name not in own_state:
                print("=> not load weights '{}'".format(name))
                continue
            if isinstance(param, Parameter):
                param = param.data
            if own_state[name].shape[0] == param.shape[0]:
                own_state[name].copy_(param)
            else:
                print("First dim different. Not loading weights {}".format(name))

    if opt.freeze:
        print("\n\t\tFreezing basenet parameters\n")
        for param in net.parameters():
            param.requires_grad = False

        # Re-enable gradients only for the three cholesky FC layers.
        for fc in (net.cholesky.fc_1, net.cholesky.fc_2, net.cholesky.fc_3):
            fc.bias.requires_grad = True
            fc.weight.requires_grad = True
    else:
        print("\n\t\tNot freezing anything. Tuning every parameter\n")
        for param in net.parameters():
            param.requires_grad = True

    net = torch.nn.DataParallel(net).cuda()  # use multiple GPUs

    # Optimizer (only over parameters that still require gradients)
    if opt.optimizer == "rmsprop":
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, net.parameters()), lr=opt.lr, alpha=0.99,
                                        eps=1e-8, momentum=0, weight_decay=0)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=opt.lr)
    else:
        print("Unknown Optimizer. Aborting!!!")
        # Abort with a non-zero status so callers can detect the failure.
        sys.exit(1)
    print(type(optimizer))

    # Optionally resume from a checkpoint
    resuming = opt.resume_prefix != ''
    if resuming:
        checkpoint.save_prefix = os.path.join(exp_dir, opt.resume_prefix)
        # The trailing character of the joined path is dropped before use.
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print("Save prefix                           = {}".format(checkpoint.save_prefix))

    # Load data
    json_path = opt.json_path
    train_json = opt.train_json
    val_json = opt.val_json

    print("Path added to each image path in JSON = {}".format(json_path))
    print("Train JSON path                       = {}".format(train_json))
    print("Val JSON path                         = {}".format(val_json))

    if opt.bulat_aug:
        # Use Bulat et al Augmentation Scheme
        train_loader = torch.utils.data.DataLoader(
            FACE(train_json, json_path, is_train=True, scale_factor=0.2, rot_factor=50, use_occlusion=True, keep_pts_inside=True),
            batch_size=opt.bs, shuffle=True,
            num_workers=opt.nThreads, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            FACE(train_json, json_path, is_train=True, keep_pts_inside=True),
            batch_size=opt.bs, shuffle=True,
            num_workers=opt.nThreads, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        FACE(val_json, json_path, is_train=False),
        batch_size=opt.bs, shuffle=False,
        num_workers=opt.nThreads, pin_memory=True)

    logger = Logger(os.path.join(exp_dir, opt.resume_prefix+'face-training-log.txt'),
                    title='face-training-summary')
    logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train RMSE', 'Val RMSE', 'Train RMSE Box', 'Val RMSE Box', 'Train RMSE Meta', 'Val RMSE Meta'])
    if not opt.is_train:
        visualizer.log_path = os.path.join(exp_dir, 'val_log.txt')
        # Same call shape as in the training loop below: `opt` is passed
        # and the predictions are the last returned value. The earlier
        # six-argument / three-value form did not match that usage.
        val_results = validate(val_loader, net,
                               train_history.epoch[-1]['epoch'], visualizer,
                               opt, num_classes, flip_index)
        checkpoint.save_preds(val_results[-1])
        return

    weights_HG = [float(x) for x in opt.hg_wt.split(",")]

    if opt.is_covariance:
        print("Covariance used from the heatmap")
    else:
        print("Covariance calculated from MLP")

    if opt.stn:
        print("Using spatial transformer on heatmaps")
    print("Postprocessing applied                = {}".format(opt.pp))
    if opt.smax:
        print("Scaled softmax used with tau          = {}".format(opt.tau))
    else:
        print("No softmax used")

    print("Individual Hourglass loss weights")
    print(weights_HG)
    print("wt_MSE (tradeoff between GLL and MSE in each hourglass)= " + str(opt.wt_mse))
    print("wt_gauss_regln (tradeoff between GLL and Gaussian Regularisation in each hourglass)= " + str(opt.wt_gauss_regln))

    if opt.bulat_aug:
        print("Using Bulat et al, ICCV 2017 Augmentation Scheme")

    print("Using Learning Policy {}".format(opt.lr_policy))
    chosen_lr_policy = dict_of_functions[opt.lr_policy]

    # When resuming, continue after the last epoch recorded in the history.
    start_epoch = 0
    if resuming:
        start_epoch = train_history.epoch[-1]['epoch'] + 1

    # Per-epoch metric traces; they are filled below but not read again
    # inside this function.
    train_loss_orig_epoch = []
    train_loss_gau_t1_epoch = []
    train_loss_gau_t2_epoch = []
    train_nme_orig_epoch = []
    train_nme_gau_epoch = []

    val_loss_orig_epoch = []
    val_loss_gau_t1_epoch = []
    val_loss_gau_t2_epoch = []
    val_nme_orig_epoch = []
    val_nme_gau_epoch = []

    for epoch in range(start_epoch, opt.nEpochs):
        chosen_lr_policy(opt, optimizer, epoch)
        # Train for one epoch
        (train_loss, train_loss_mse, train_loss_gau_t1, train_loss_gau_t2,
         train_rmse_orig, train_rmse_gau, train_rmse_new_gd_box,
         train_rmse_new_meta_box) = train(train_loader, net, optimizer, epoch,
                                          visualizer, opt)
        train_loss_gau_t1_epoch.append(train_loss_gau_t1)
        train_loss_gau_t2_epoch.append(train_loss_gau_t2)
        train_nme_orig_epoch.append(train_rmse_orig)
        train_nme_gau_epoch.append(train_rmse_gau)
        train_loss_orig_epoch.append(train_loss_mse)

        # Evaluate on validation set
        (val_loss, val_loss_mse, val_loss_gau_t1, val_loss_gau_t2,
         val_rmse_orig, val_rmse_gau, val_rmse_new_gd_box,
         val_rmse_new_meta_box, predictions) = validate(
             val_loader, net, epoch, visualizer, opt, num_classes, flip_index)
        val_loss_orig_epoch.append(val_loss_mse)
        val_loss_gau_t1_epoch.append(val_loss_gau_t1)
        val_loss_gau_t2_epoch.append(val_loss_gau_t2)
        val_nme_orig_epoch.append(val_rmse_orig)
        val_nme_gau_epoch.append(val_rmse_gau)

        # Update training history
        current_lr = optimizer.param_groups[0]['lr']
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', current_lr)])
        loss = OrderedDict([('train_loss', train_loss), ('val_loss', val_loss)])
        rmse = OrderedDict([('val_rmse', val_rmse_gau)])
        train_history.update(e, lr, loss, rmse)
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        visualizer.plot_train_history_face(train_history)
        logger.append([epoch, current_lr, train_loss, val_loss, train_rmse_gau, val_rmse_gau, train_rmse_new_gd_box, val_rmse_new_gd_box, train_rmse_new_meta_box, val_rmse_new_meta_box])

    logger.close()
def train():
    """Train the CNN model using manually accumulated mini-batches.

    Configuration comes from ``configs.get_args()``. When ``args.resume`` is
    set, the model, optimizer wrapper, epoch and best score are restored from
    the "best" checkpoint under ``./experiment/cnn_net``; otherwise a fresh
    Adam optimizer wrapped in ``NoamOpt`` is created.

    The data loader yields one sample at a time. Per-sample losses are
    collected and a single optimizer step is taken every ``args.batch_size``
    samples. Samples left over at the end of an epoch (fewer than
    ``args.batch_size``) are discarded without a step.

    A "best" checkpoint is written whenever the score measured at the start
    of an epoch exceeds the best score seen so far.
    """
    args = configs.get_args()
    use_cuda = args.use_cuda and torch.cuda.is_available()

    # Dataset and a shuffled loader that emits single samples.
    dataset = libs.dataset.MyDataset(min_length=args.min_length)
    voc_size = dataset.get_voc_size()
    dataloader = DataLoader(dataset, 1, True, drop_last=False)

    # Model
    model = models.TopModuleCNN(voc_size, output_channel=args.output_channel)
    if use_cuda:
        model = model.cuda()

    if not args.resume:
        # Fresh run.
        start_epoch, max_ans_acc = 1, 0
        optimizer = NoamOpt(
            512, 1, 2000,
            optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    else:
        # Restore everything from the best checkpoint.
        best_path = Checkpoint.get_certain_checkpoint(
            "./experiment/cnn_net", "best")
        restored = Checkpoint.load(best_path)
        model = restored.model
        optimizer = restored.optimizer

        # Rebuild the wrapped torch optimizer so that it is bound to the
        # parameters of the restored model, reusing the first param group's
        # hyper-parameters.
        inner = optimizer.optimizer
        group_options = inner.param_groups[0]
        group_options.pop('params', None)
        optimizer.optimizer = inner.__class__(model.parameters(),
                                              **group_options)

        start_epoch = restored.epoch
        max_ans_acc = restored.max_ans_acc

    # Class-weighted cross entropy.
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 4.]))
    if use_cuda:
        criterion = criterion.cuda()

    for epoch in range(start_epoch, args.epochs):
        # Score the current model before training this epoch.
        # NOTE(review): when resuming, the score is always the restored best
        # value, so the "save checkpoint" condition below can never be true.
        if args.resume:
            test_ans_acc = max_ans_acc
        else:
            eval_loader = DataLoader(dataset, 1, True, drop_last=False)
            test_ans_acc = test(eval_loader, model, epoch)
        print('For EPOCH {}, total f1: {:.2f}'.format(epoch, test_ans_acc))

        pending = []  # per-sample losses awaiting the next optimizer step
        for step, data in enumerate(dataloader, 1):
            x = data['que'].long()
            y = data['ans'].long()
            res = data['res'].long()
            if use_cuda:
                x, y, res = x.cuda(), y.cuda(), res.cuda()
            res_pred = model(x, y)

            pending.append(criterion(res_pred, res).unsqueeze(0))

            if step % args.batch_size != 0:
                continue

            # A full mini-batch has been collected: sum the losses and step.
            batch_loss = torch.cat(pending).sum()
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()
            pending = []
            print('EPOCH: {}, {} / {}====> LOSS: {:.2f}'.format(
                epoch, step // args.batch_size,
                len(dataloader) // args.batch_size,
                batch_loss.item() / args.batch_size))

        # Keep only the best-scoring model.
        if test_ans_acc > max_ans_acc:
            max_ans_acc = test_ans_acc
            best = Checkpoint(model=model,
                              optimizer=optimizer,
                              epoch=epoch,
                              max_ans_acc=max_ans_acc)
            best.save_according_name("./experiment/cnn_net", 'best')
Exemple #12
0
def main():
    """Command-line entry point: train (or fine-tune) a sequence regressor.

    Parses the command-line options, builds the model, checkpoint helper,
    early-stopping helper and Visdom wrapper, assembles the training and
    validation loaders (pre-training targets from ``--Y_list`` or, with
    ``--transfer``, the ``Y18`` data), and runs the epoch loop until the
    epoch limit or early stopping. After training it reloads the model via
    the checkpoint helper, refreshes the prediction plots and, with
    ``--predict``, writes test predictions.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--device', dict(type=str, default='gpu',
                          help='For cpu: \'cpu\', for gpu: \'gpu\'')),
        ('--chunk_size', dict(type=int, default=36,
                              help='chunk size(sequence length)')),
        ('--step_size', dict(type=int, default=1,
                             help='sequence split step')),
        ('--lr', dict(type=float, default=5e-4, help='learning rate')),
        ('--weight_decay', dict(type=argtype.check_float, default='1e-2',
                                help='weight_decay')),
        ('--epoch', dict(type=argtype.epoch, default='inf',
                         help='the number of epoch for training')),
        ('--batch_size', dict(type=int, default=256,
                              help='size of batches for training')),
        ('--val_ratio', dict(type=float, default=.3,
                             help='validation set ratio')),
        ('--model_name', dict(type=str, default='main_model',
                              help='model name to save')),
        ('--transfer', dict(type=argtype.boolean, default=False,
                            help='whether fine tuning or not')),
        ('--oversample_times', dict(
            type=int, default=30,
            help='the times oversampling times for fine tuning')),
        ('--patience', dict(type=int, default=20,
                            help='patience for early stopping')),
        ('--c_loss', dict(type=argtype.boolean, default=True,
                          help='whether using custom loss or not')),
        ('--predict', dict(type=argtype.boolean, default=False,
                           help='predict and save csv file or not')),
        ('--filename', dict(type=str, default='submission',
                            help='csv file name to save predict result')),
        ('--Y_list', dict(type=argtype.str_to_list, default='Y12,Y15',
                          help='target Y for pre-training')),
        ('--window_size', dict(type=int, default=1,
                               help='window size for moving average')),
        ('--attention', dict(type=argtype.boolean, default=True,
                             help='select model using attention mechanism')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    data_dir = './data'

    # 'gpu' on the command line maps to the torch device name 'cuda'.
    if args.device == 'gpu':
        args.device = 'cuda'
    device = torch.device(args.device)

    chunk_size, step_size = args.chunk_size, args.step_size
    lr, weight_decay = args.lr, args.weight_decay
    EPOCH = args.epoch
    batch_size, val_ratio = args.batch_size, args.val_ratio
    model_name = args.model_name
    transfer_learning = args.transfer
    times = args.oversample_times
    patience = args.patience
    c_loss = args.c_loss
    pred = args.predict
    filename = args.filename
    Y_list = args.Y_list
    window_size = args.window_size
    attention = args.attention

    # Snapshot of the run configuration shown in the visualiser; note that
    # 'model_name' here is the raw option, before the target suffix is added.
    params = dict([
        ('chunk_size', chunk_size),
        ('step_size', step_size),
        ('learning_rate', lr),
        ('weight_decay', weight_decay),
        ('epoch size', EPOCH),
        ('batch_size', batch_size),
        ('valid_ratio', val_ratio),
        ('model_name', model_name),
        ('transfer_learning', transfer_learning),
        ('oversample_times', times),
        ('early_stopping_patience', patience),
        ('c_loss', c_loss),
        ('pred', pred),
        ('filename', filename),
        ('Y_list', Y_list),
        ('window_size', window_size),
        ('attention', attention),
    ])

    # Concatenated target names identify the model directory.
    Y = ''.join(Y_list)
    model_name = f'{model_name}/{Y}'

    frames = dataframe.Dataframe(data_dir=data_dir)
    input_size = len(frames.feature_cols)

    if attention:
        model = regressor.Attention_Regressor(input_size).to(device)
    else:
        model = regressor.BiLSTM_Regressor().to(device)

    checkpoint = Checkpoint(model_name=model_name,
                            transfer_learning=transfer_learning)
    early_stopping = Early_stopping(patience=patience)
    vis = Custom_Visdom(model_name, transfer_learning)
    vis.print_params(params)

    # Build the list of datasets that make up the training pool.
    if transfer_learning:
        if attention:
            # Prepend the last chunk_size - 1 pre-training feature rows to
            # the Y18 frame.
            pre_df = frames.get_pretrain_df()\
                    .iloc[-chunk_size+1:][frames.feature_cols]
            df = pd.concat([pre_df, frames.get_y18_df()], axis=0)
        else:
            df = frames.get_y18_df()

        dataset_list = [
            datasets.CustomSequenceDataset(chunk_size=chunk_size,
                                           df=df,
                                           Y='Y18',
                                           step_size=step_size,
                                           noise=True,
                                           times=times)
        ]
    else:
        dataset_list = []
        for y in Y_list:
            # Smooth each target with a moving average before windowing.
            df = frames.get_pretrain_df()
            df[y] = df[y].rolling(window=window_size, min_periods=1).mean()
            dataset_list.append(
                datasets.CustomSequenceDataset(chunk_size=chunk_size,
                                               df=df,
                                               Y=y,
                                               step_size=step_size,
                                               noise=False,
                                               times=1))

    dataset = ConcatDataset(dataset_list)
    train_loader, valid_loader = datasets.split_dataset(
        dataset=dataset,
        batch_size=batch_size,
        val_ratio=val_ratio,
        shuffle=True)

    if transfer_learning:
        # Fine-tuning starts from previously saved weights.
        checkpoint.load_model(model)

    optimizer = Adam(model.parameters(),
                     lr=lr,
                     weight_decay=float(weight_decay))

    criterion = custom_loss.mse_AIFrenz_torch if c_loss else nn.MSELoss()

    training_time = time.time()
    epoch = 0
    y_df = frames.get_pretrain_df()[Y_list]
    y18_df = frames.get_y18_df()[['Y18']]

    def _refresh_prediction_plots():
        # Predict on the training data for the first pre-training target and
        # for Y18, then redraw both prediction plots and the error summary.
        _, y_pred, _ = predict.trainset_predict(
            model=model,
            data_dir=data_dir,
            Y=Y_list[0],
            chunk_size=chunk_size,
            attention=attention,
            window_size=window_size)

        _, y18_pred, _ = predict.trainset_predict(
            model=model,
            data_dir=data_dir,
            Y='Y18',
            chunk_size=chunk_size,
            attention=attention,
            window_size=window_size)

        y_df['pred'] = y_pred
        y18_df['pred'] = y18_pred

        vis.predict_plot(y_df, 'pre')
        vis.predict_plot(y18_df, 'trans')
        vis.print_error()

    while epoch < EPOCH:

        print(f'\r Y: {Y} \
              chunk size: {chunk_size} \
              transfer: {transfer_learning}')

        epoch += 1
        train_loss_per_epoch, train_loss_list_per_batch, batch_list = train(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            epoch=epoch,
            transfer_learning=transfer_learning,
            attention=attention,
            freeze_name='transfer_layer')

        valid_loss = valid(model=model,
                           valid_loader=valid_loader,
                           criterion=criterion,
                           attention=attention)

        # Wall-clock time since training started (cumulative, not per epoch).
        iter_time = time.time() - training_time

        print(
            f'\r Epoch: {epoch:3d}/{str(EPOCH):3s}\t',
            f'train time: {int(iter_time//60):2d}m {iter_time%60:5.2f}s\t'
            f'avg train loss: {train_loss_per_epoch:7.3f}\t'
            f'valid loss: {valid_loss:7.3f}')

        checkpoint.save_log(batch_list, epoch, train_loss_list_per_batch,
                            train_loss_per_epoch, valid_loss)

        early_stop, is_best = early_stopping(valid_loss)
        checkpoint.save_checkpoint(model, optimizer, is_best)

        vis.print_training(EPOCH, epoch, training_time, train_loss_per_epoch,
                           valid_loss, patience, early_stopping.counter)
        vis.loss_plot(checkpoint)

        print('-----' * 17)

        _refresh_prediction_plots()

        if early_stop:
            break

    # Reload weights through the checkpoint helper before the final plots.
    checkpoint.load_model(
        model, transfer_learningd=True if transfer_learning else False)

    _refresh_prediction_plots()

    if pred:
        predict.test_predict(model=model,
                             chunk_size=chunk_size,
                             filename=filename,
                             attention=attention)
Exemple #13
0
def main():
    """Build the CU-Net, load pre-trained weights and run validation.

    Steps, in order:
      1. Parse options and create the history, checkpoint and visualiser
         helpers.
      2. Create the network and copy matching tensors from
         ``opt.saved_wt_file`` into it (the first 7 characters of every saved
         key are stripped before matching).
      3. Wrap the network in ``DataParallel``, create the optimizer, and
         optionally restore a checkpoint selected by ``opt.resume_prefix``.
      4. Build the train/validation loaders.
      5. If ``opt.is_train`` is false, validate once, save the predictions
         and return.
      6. Otherwise print the loss configuration and run a single validation
         pass, recording the resulting metrics in per-epoch lists.

    Relies on the module-level name ``flip_index`` and sets the module-level
    globals ``f_path`` and ``weights_HG``.
    """
    opt = TrainOptions().parse()
    train_history = TrainHistoryFace()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + '_val_log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    num_classes = opt.class_num

    if not opt.slurm:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    layer_num = opt.layer_num
    order = opt.order
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        class_num=num_classes,
                        layer_num=layer_num,
                        order=order,
                        loss_num=layer_num,
                        use_spatial_transformer=opt.stn,
                        mlp_tot_layers=opt.mlp_tot_layers,
                        mlp_hidden_units=opt.mlp_hidden_units,
                        get_mean_from_mlp=opt.get_mean_from_mlp)

    # Load the pre-trained model
    saved_wt_file = opt.saved_wt_file
    print("Loading weights from " + saved_wt_file)
    checkpoint_t = torch.load(saved_wt_file)
    state_dict = checkpoint_t['state_dict']

    # Build the target state dict once; its tensors alias the network's
    # parameters and buffers, so in-place copy_ updates the network.
    target_state = net.state_dict()
    for name, param in state_dict.items():
        # Drop the 7-character key prefix (presumably "module." left by a
        # DataParallel wrapper at save time -- TODO confirm).
        name = name[7:]
        if name not in target_state:
            print("=> not load weights '{}'".format(name))
            continue
        if isinstance(param, Parameter):
            param = param.data
        target_state[name].copy_(param)

    net = torch.nn.DataParallel(net).cuda()  # use multiple GPUs

    # Optimizer over the parameters that require gradients only.
    if opt.optimizer == "rmsprop":
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                               net.parameters()),
                                        lr=opt.lr,
                                        alpha=0.99,
                                        eps=1e-8,
                                        momentum=0,
                                        weight_decay=0)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            net.parameters()),
                                     lr=opt.lr)
    else:
        print("Unknown Optimizer. Aborting!!!")
        # NOTE(review): this abort path exits with status 0; kept unchanged
        # because callers may depend on it.
        sys.exit(0)
    print(type(optimizer))

    # Optionally resume from a checkpoint
    resuming = opt.resume_prefix != ''
    if resuming:
        checkpoint.save_prefix = os.path.join(exp_dir, opt.resume_prefix)
        # The load prefix is the save prefix without its last character.
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print("Save prefix                           = {}".format(
        checkpoint.save_prefix))

    # Load data
    json_path = opt.json_path
    train_json = opt.train_json
    val_json = opt.val_json

    print("Path added to each image path in JSON = {}".format(json_path))
    print("Train JSON path                       = {}".format(train_json))
    print("Val JSON path                         = {}".format(val_json))

    # This train loader is useless
    train_loader = torch.utils.data.DataLoader(FACE(train_json,
                                                    json_path,
                                                    is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(FACE(val_json,
                                                  json_path,
                                                  is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    if not opt.is_train:
        # Evaluation-only mode: one validation pass, save predictions, stop.
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        val_loss, val_rmse, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer,
            num_classes, flip_index)
        checkpoint.save_preds(predictions)
        return

    global f_path
    global weights_HG

    f_path = exp_dir
    weights_HG = [float(x) for x in opt.hg_wt.split(",")]

    print("Postprocessing applied                = {}".format(opt.pp))
    if opt.smax:
        print("Scaled softmax used with tau          = {}".format(opt.tau))
    else:
        print("No softmax used")

    if opt.is_covariance:
        print("Covariance used from the heatmap")
    else:
        print("Covariance calculated from MLP")

    print("Individual Hourglass loss weights")
    print(weights_HG)
    print("wt_MSE (tradeoff between GLL and MSE in each hourglass)= " +
          str(opt.wt_mse))
    print(
        "wt_gauss_regln (tradeoff between GLL and Gaussian Regularisation in each hourglass)= "
        + str(opt.wt_gauss_regln))

    # Epoch to continue from when resuming; 0 for a fresh run.
    start_epoch = 0
    if resuming:
        start_epoch = train_history.epoch[-1]['epoch'] + 1

    # Per-epoch metric histories.
    train_loss_orig_epoch = []
    train_loss_gau_t1_epoch = []
    train_loss_gau_t2_epoch = []
    train_nme_orig_epoch = []
    train_nme_gau_epoch = []
    train_nme_new_epoch = []

    val_loss_orig_epoch = []
    val_loss_gau_t1_epoch = []
    val_loss_gau_t2_epoch = []
    val_nme_orig_epoch = []
    val_nme_gau_epoch = []
    val_nme_new_epoch = []

    for epoch in range(1):
        # Evaluate on validation set
        val_loss, val_loss_mse, val_loss_gau_t1, val_loss_gau_t2, val_rmse_orig, val_rmse_gau, val_rmse_new_box, predictions = validate(
            val_loader, net, epoch, visualizer, opt, num_classes, flip_index)
        val_loss_orig_epoch.append(val_loss_mse)
        val_loss_gau_t1_epoch.append(val_loss_gau_t1)
        val_loss_gau_t2_epoch.append(val_loss_gau_t2)
        val_nme_orig_epoch.append(val_rmse_orig)
        val_nme_gau_epoch.append(val_rmse_gau)
Exemple #14
0
    "patch_size": 8
})
# Build the discriminator from the shared args and move it to the target device.
Dis = Discriminator(args).to(device)
# summary(Dis,(3,32,32,))

# One AdamW optimizer per network, sharing the same learning rate and betas.
optG = optim.AdamW(Gen.parameters(), lr=lr, betas=(beta1, beta2))
optD = optim.AdamW(Dis.parameters(), lr=lr, betas=(beta1, beta2))

# Training bookkeeping containers and counters.
img_list = []
G_losses = []
D_losses = []
loss_logs = {"gen_loss": [], "dis_loss": []}
iters = 0

# Checkpoint helper writing into ckp_folder.
# NOTE(review): num_ckps / start_after semantics are defined by the Checkpoint
# class, which is not visible here -- confirm before changing these values.
ckp_class = Checkpoint(ckp_folder,
                       max_epochs=num_epochs,
                       num_ckps=5,
                       start_after=0.1)

# Ask the checkpoint helper for previously saved state. It hands back the
# (possibly restored) networks and optimizers, the epoch to start from and
# any previously saved loss logs.
Gen, Dis, optG, optD, start_epoch, old_logs = ckp_class.check_if_exists(
    Gen, Dis, optG, optD)

# Keep the freshly initialised logs when no saved logs were returned.
loss_logs = old_logs or loss_logs
print(start_epoch)  # , loss_logs

# Commented out IPython magic to ensure Python compatibility.
for epoch in range(start_epoch, num_epochs + 1):
    for i, data in enumerate(dataset.train_loader):

        ###########################
def main():
    """Pre-train the scale/rotation agent network against a loaded pose model.

    Creates the output directory ``<exp_dir>/<exp_id>/<sr_dir>-<pose prefix>``,
    builds MPII train/validation loaders from the ground-truth scale/rotation
    distribution files stored there, creates the agent network and its
    RMSprop optimizer (optionally restoring both from ``opt.load_prefix_sr``),
    loads the hourglass pose network from ``opt.load_prefix_pose``, and then
    runs the train/validate loop. Only the agent's parameters are handed to
    the optimizer. After every epoch the history is updated, an agent
    checkpoint is saved and a row is appended to ``training-summary.txt``.

    All ``print`` calls take a single argument so the function parses on both
    Python 2 and Python 3 and prints the same text on either.
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id, opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    train_history = ASNTrainHistory()
    checkpoint_agent = Checkpoint()
    visualizer = Visualizer(opt)

    # Files living in the pre-training directory.
    visualizer.log_path = sr_pretrain_dir + '/' + 'log.txt'
    train_scale_path = sr_pretrain_dir + '/' + 'train_scales.txt'
    train_rotation_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    val_scale_path = sr_pretrain_dir + '/' + 'val_scales.txt'
    val_rotation_path = sr_pretrain_dir + '/' + 'val_rotations.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    print('collecting training scale and rotation distributions ...\n')
    train_scale_distri = read_grnd_distri_from_txt(train_scale_path)
    train_rotation_distri = read_grnd_distri_from_txt(train_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=True,
                   grnd_scale_distri=train_scale_distri,
                   grnd_rotation_distri=train_rotation_distri)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    print('collecting validation scale and rotation distributions ...\n')
    val_scale_distri = read_grnd_distri_from_txt(val_scale_path)
    val_rotation_distri = read_grnd_distri_from_txt(val_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=False,
                   grnd_scale_distri=val_scale_distri,
                   grnd_rotation_distri=val_rotation_distri)
    val_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    # The agent's output sizes come from the validation dataset created just
    # above (``dataset`` was rebound to it).
    agent = model.create_asn(chan_in=256,
                             chan_out=256,
                             scale_num=len(dataset.scale_means),
                             rotation_num=len(dataset.rotation_means),
                             is_aug=True)
    agent = torch.nn.DataParallel(agent).cuda()
    optimizer = torch.optim.RMSprop(agent.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    if opt.load_prefix_sr == '':
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/'
    else:
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr
        # The load prefix is the save prefix without its last character.
        checkpoint_agent.load_prefix = checkpoint_agent.save_prefix[0:-1]
        checkpoint_agent.load_checkpoint(agent, optimizer, train_history)
    # Same text a Python 2 ``print a, b, c`` statement would emit: the items
    # separated by single spaces (the first item itself ends with a space).
    print('agent:  {} {}'.format(type(optimizer),
                                 optimizer.param_groups[0]['lr']))

    # NOTE(review): num_classes is only bound for the 'mpii' dataset; any
    # other value of opt.dataset fails at the create_hg call below.
    if opt.dataset == 'mpii':
        num_classes = 16
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    checkpoint_hg = Checkpoint()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    logger = Logger(sr_pretrain_dir + '/' + 'training-summary.txt',
                    title='training-summary')
    logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss'])

    # Training and validation; continue after the last recorded epoch when an
    # agent checkpoint was restored.
    start_epoch = 0
    if opt.load_prefix_sr != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        # train for one epoch
        train_loss = train(train_loader, hg, agent, optimizer, epoch,
                           visualizer, opt)
        val_loss = validate(val_loader, hg, agent, epoch, visualizer, opt)
        # update training history
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        train_history.update(e, lr, loss)
        checkpoint_agent.save_checkpoint(agent,
                                         optimizer,
                                         train_history,
                                         is_asn=True)
        visualizer.plot_train_history(train_history, 'sr')
        logger.append(
            [epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss])
    logger.close()
def main():
    """Train and evaluate the Inception_a network, checkpointing every epoch.

    Parses the command-line arguments, creates the TensorBoard writer and the
    data loaders, builds the network and initialises it from the single-frame
    pre-trained weights in ``models/pretrain_inc_sf.pth`` (weights are matched
    to the network's state-dict keys by position, not by name), optionally
    restores a checkpoint, and then runs the train/validate loop starting at
    the checkpoint's ``b_epoch``. After each epoch the scalars are logged and
    the checkpoint recorder is updated and saved.

    Sets the module-level globals ``writer`` and ``cp_recorder``.

    All ``print`` calls take a single argument and the state-dict keys are
    materialised as a list, so the function runs on both Python 2 and 3.
    """
    # global args, best_prec1
    args = parser.parse_args()
    print('\n====> Input Arguments')
    print(args)
    # Tensorboard writer.
    global writer
    writer = SummaryWriter(log_dir=args.result_path)

    # Create dataloader.
    print('\n====> Creating dataloader...')
    train_loader = get_train_loader(args)
    test_loader = get_test_loader(args)

    # Load Resnet_a network.
    print('====> Loading the network...')
    model = Inception_a(num_class=args.num_class,
                        num_frame=args.num_frame,
                        pretrained=True)

    # Load single frame pretrain: the i-th pre-trained tensor is assigned to
    # the i-th key of the network's state dict.
    pretrain_model = torch.load('models/pretrain_inc_sf.pth')
    # list(...) is required on Python 3, where dict.keys() is a view that
    # cannot be indexed.
    keys = list(model.state_dict().keys())
    new_state_dict = {}
    for i, k in enumerate(pretrain_model.keys()):
        new_state_dict[keys[i]] = pretrain_model[k]
    model.load_state_dict(new_state_dict)

    # Load checkpoint and weight of network.
    # NOTE(review): cp_recorder is only assigned here when
    # args.checkpoint_dir is set, yet the training loop below always uses it.
    global cp_recorder
    if args.checkpoint_dir:
        cp_recorder = Checkpoint(args.checkpoint_dir, args.checkpoint_name)
        cp_recorder.load_checkpoint(model)

    model = nn.DataParallel(model)
    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.wd,
                                momentum=args.momentum)

    # Train Resnet_a model.
    print('====> Training...')
    for epoch in range(cp_recorder.contextual['b_epoch'], args.epoch):
        _, _, prec_tri, rec_tri, ap_tri = train_eval(train_loader, test_loader,
                                                     model, criterion,
                                                     optimizer, args, epoch)
        top1_avg_val, loss_avg_val, prec_val, rec_val, ap_val = validate_eval(
            test_loader, model, criterion, args, epoch)

        # Print result.
        writer.add_scalars('mAP (per epoch)',
                           {'train': np.nan_to_num(ap_tri).mean()}, epoch)
        writer.add_scalars('mAP (per epoch)',
                           {'valid': np.nan_to_num(ap_val).mean()}, epoch)

        # Debug switch: set to True to dump the per-class scores.
        print_score = False
        if print_score:
            print('\n====> Scores')
            print(
                '[Epoch {0}]:\n'
                '  Train:\n'
                '    Prec@1 {1}\n'
                '    Recall {2}\n'
                '    AP {3}\n'
                '    mAP {4:.3f}\n'
                '  Valid:\n'
                '    Prec@1 {5}\n'
                '    Recall {6}\n'
                '    AP {7}\n'
                '    mAP {8:.3f}\n'.format(epoch, prec_tri, rec_tri, ap_tri,
                                           np.nan_to_num(ap_tri).mean(),
                                           prec_val, rec_val, ap_val,
                                           np.nan_to_num(ap_val).mean()))

        # Record validation scalars on the per-batch axis, at the step that
        # corresponds to the end of this epoch.
        writer.add_scalars('Loss (per batch)', {'valid': loss_avg_val},
                           (epoch + 1) * len(train_loader))
        writer.add_scalars('Prec@1 (per batch)', {'valid': top1_avg_val},
                           (epoch + 1) * len(train_loader))
        writer.add_scalars('mAP (per batch)',
                           {'valid': np.nan_to_num(ap_val).mean()},
                           (epoch + 1) * len(train_loader))

        # Save checkpoint.
        cp_recorder.record_contextual({
            'b_epoch': epoch + 1,
            'b_batch': -1,
            'prec': top1_avg_val,
            'loss': loss_avg_val,
            'class_prec': prec_val,
            'class_recall': rec_val,
            'class_ap': ap_val,
            'mAP': np.nan_to_num(ap_val).mean()
        })
        cp_recorder.save_checkpoint(model)
Exemple #17
0
    if arguments.fix_a is None and arguments.reg_type == "swd" and arguments.pruning_iterations != 1:
        print('Progressive a is not compatible with iterative pruning')
        raise ValueError
    if arguments.no_ft and arguments.pruning_iterations != 1:
        print("You can't specify a pruning_iteration value if there is no fine-tuning at all")
        raise ValueError
    get_mask = get_mask_function(arguments.pruning_type)
    _dataset = get_dataset(arguments)
    _targets = [int((n + 1) * (arguments.target / arguments.pruning_iterations)) for n in
                range(arguments.pruning_iterations)]

    # Train model
    print('Train model !')
    print(f'Regularization with t-{_targets[0]}')

    training_model = Checkpoint(arguments, 'training')
    training_model.regularization = Regularization(None, _targets[0], arguments)
    training_model.load()
    train_model(training_model, arguments, [0, arguments.epochs], _dataset, None, soft_pruning=arguments.soft_pruning)

    if arguments.lr_rewinding:
        training_model.rewind_lr()

    if arguments.no_ft:
        print('\nPruning model without fine tuning :')
        pruned_model = training_model.clone('pruned')
        pruned_model.load()
        mask = get_mask(pruned_model.model, arguments.target)
        apply_mask(pruned_model.model, mask)
        _acc, _top5, _test_loss = test_model(_dataset, pruned_model.model, arguments)
        pruned_model.save_results({'epoch': 'before', 'acc': _acc, 'top5': _top5, 'loss': _test_loss,
def main():
    """Alternate training of the pose network and the scale/rotation agent.

    Parses the options, restores the ``hg`` network (``model.create_hg``) and
    the ``agent_sr`` network (``model.create_asn``) from checkpoints, then for
    every epoch runs ``train_hg`` + ``validate`` for the pose network followed
    by ``train_agent_sr`` for the agent, saving a checkpoint and plotting the
    history after each of them.  When ``opt.is_train`` is false only a single
    validation pass is run and its predictions are saved.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'mpii'`` (previously this
            surfaced later as a NameError on ``num_classes``).
    """
    opt = TrainOptions().parse()
    if opt.joint_dir == '':
        print('joint directory is null.')
        exit()
    # [0:-1] drops the last character of the pose prefix -- presumably a
    # trailing separator; TODO confirm against the option's expected format.
    joint_dir = os.path.join(opt.exp_dir, opt.exp_id,
                             opt.joint_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(joint_dir):
        os.makedirs(joint_dir)

    visualizer = Visualizer(opt)
    visualizer.log_path = joint_dir + '/' + 'train-log.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Only the MPII setting is known here; fail with a clear message instead
    # of leaving ``num_classes`` unbound for any other value.
    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        raise ValueError(
            "unsupported dataset {!r}: only 'mpii' is handled".format(
                opt.dataset))
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    # optimizer for the pose network
    optimizer_hg = torch.optim.RMSprop(hg.parameters(),
                                       lr=opt.lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    train_history_pose = PoseTrainHistory()
    checkpoint_hg = Checkpoint()
    if opt.load_checkpoint:
        # Load from the joint-training directory.
        checkpoint_hg.load_prefix = joint_dir + '/' + opt.load_prefix_pose[0:-1]
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
    else:
        # Load from the experiment directory, then force the learning rate
        # requested on the command line onto every parameter group.
        checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id) + \
                                    '/' + opt.load_prefix_pose[0:-1]
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
        for param_group in optimizer_hg.param_groups:
            param_group['lr'] = opt.lr
    checkpoint_hg.save_prefix = joint_dir + '/pose-'
    # Single-argument print call: same text as the old Python 2 print
    # statement, and valid syntax on Python 3 as well.
    print('hg optimizer:  {} {}'.format(
        type(optimizer_hg), optimizer_hg.param_groups[0]['lr']))

    agent_sr = model.create_asn(chan_in=256,
                                chan_out=256,
                                scale_num=len(dataset.scale_means),
                                rotation_num=len(dataset.rotation_means),
                                is_aug=True)
    agent_sr = torch.nn.DataParallel(agent_sr).cuda()
    optimizer_sr = torch.optim.RMSprop(agent_sr.parameters(),
                                       lr=opt.agent_lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    if opt.load_prefix_sr == '':
        print('please input the checkpoint name of the sr agent.')
        exit()
    train_history_sr = ASNTrainHistory()
    checkpoint_sr = Checkpoint()
    if opt.load_checkpoint:
        checkpoint_sr.load_prefix = joint_dir + '/' + opt.load_prefix_sr[0:-1]
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
    else:
        # Load the agent from its own pre-training directory and reset the
        # learning rate to the configured agent value.
        sr_pretrain_dir = os.path.join(
            opt.exp_dir, opt.exp_id,
            opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
        checkpoint_sr.load_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr[
            0:-1]
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
        for param_group in optimizer_sr.param_groups:
            param_group['lr'] = opt.agent_lr
    checkpoint_sr.save_prefix = joint_dir + '/agent-'
    print('agent optimizer:  {} {}'.format(
        type(optimizer_sr), optimizer_sr.param_groups[0]['lr']))

    train_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                            '/bigdata1/zt53/data',
                            is_train=True)
    train_loader_hg = torch.utils.data.DataLoader(train_dataset_hg,
                                                  batch_size=opt.bs,
                                                  shuffle=True,
                                                  num_workers=opt.nThreads,
                                                  pin_memory=True)
    val_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                          '/bigdata1/zt53/data',
                          is_train=False)
    val_loader_hg = torch.utils.data.DataLoader(val_dataset_hg,
                                                batch_size=opt.bs,
                                                shuffle=False,
                                                num_workers=opt.nThreads,
                                                pin_memory=True)
    train_dataset_agent = AGENT('dataset/mpii-hr-lsp-normalizer.json',
                                '/bigdata1/zt53/data',
                                separate_s_r=True)
    train_loader_agent = torch.utils.data.DataLoader(train_dataset_agent,
                                                     batch_size=opt.bs,
                                                     shuffle=True,
                                                     num_workers=opt.nThreads,
                                                     pin_memory=True)

    if not opt.is_train:
        # Validation-only mode: evaluate once at the last recorded epoch,
        # store the predictions and stop.
        visualizer.log_path = joint_dir + '/' + 'val-log.txt'
        val_loss, val_pckh, predictions = validate(
            val_loader_hg, hg, train_history_pose.epoch[-1]['epoch'],
            visualizer, num_classes)
        checkpoint_hg.save_preds(predictions)
        return
    logger = Logger(joint_dir + '/' + 'pose-training-summary.txt',
                    title='pose-training-summary')
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train PCKh', 'Val PCKh'])
    # training and validation: both networks continue from the epoch that
    # follows the last one stored in their respective histories.
    start_epoch_pose = train_history_pose.epoch[-1]['epoch'] + 1
    epoch_sr = train_history_sr.epoch[-1]['epoch'] + 1

    for epoch in range(start_epoch_pose, opt.nEpochs):
        adjust_lr(opt, optimizer_hg, epoch)
        # train hg for one epoch
        train_loss_pose, train_pckh = train_hg(train_loader_hg, hg,
                                               optimizer_hg, agent_sr, epoch,
                                               visualizer, opt)
        # evaluate on validation set
        val_loss, val_pckh, predictions = validate(val_loader_hg, hg, epoch,
                                                   visualizer, num_classes)
        # update training history
        e_pose = OrderedDict([('epoch', epoch)])
        lr_pose = OrderedDict([('lr', optimizer_hg.param_groups[0]['lr'])])
        loss_pose = OrderedDict([('train_loss', train_loss_pose),
                                 ('val_loss', val_loss)])
        pckh = OrderedDict([('train_pckh', train_pckh),
                            ('val_pckh', val_pckh)])
        train_history_pose.update(e_pose, lr_pose, loss_pose, pckh)
        checkpoint_hg.save_checkpoint(hg, optimizer_hg, train_history_pose,
                                      predictions)
        visualizer.plot_train_history(train_history_pose)
        logger.append([
            epoch, optimizer_hg.param_groups[0]['lr'], train_loss_pose,
            val_loss, train_pckh, val_pckh
        ])

        # train agent_sr for one epoch; the agent keeps its own epoch counter
        # and has no validation pass, so its 'val_loss' entry is recorded as 0.
        train_loss_sr = train_agent_sr(train_loader_agent, hg, agent_sr,
                                       optimizer_sr, epoch_sr, visualizer, opt)
        e_sr = OrderedDict([('epoch', epoch_sr)])
        lr_sr = OrderedDict([('lr', optimizer_sr.param_groups[0]['lr'])])
        loss_sr = OrderedDict([('train_loss', train_loss_sr), ('val_loss', 0)])
        train_history_sr.update(e_sr, lr_sr, loss_sr)
        checkpoint_sr.save_checkpoint(agent_sr,
                                      optimizer_sr,
                                      train_history_sr,
                                      is_asn=True)
        visualizer.plot_train_history(train_history_sr, 'sr')
        epoch_sr += 1

    logger.close()
Exemple #19
0
def main():
    """Train, or only validate, the pose network built by ``create_hg``.

    Options come from ``TrainOptions``.  If ``opt.load_prefix_pose`` is set the
    network, optimizer and history are restored through ``Checkpoint`` and
    training continues after the last recorded epoch; otherwise training
    starts at epoch 0.  With ``opt.is_train`` false a single validation pass
    is run and its predictions are saved instead.
    """
    opt = TrainOptions().parse()
    train_history = PoseTrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)

    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    visualizer.log_path = os.path.join(exp_dir, opt.vis_env + 'log.txt')
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    class_count = 16
    pose_net = create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=class_count,
                         chan=256)
    pose_net = torch.nn.DataParallel(pose_net).cuda()

    # optimizer
    rms_settings = dict(lr=opt.lr,
                        alpha=0.99,
                        eps=1e-8,
                        momentum=0,
                        weight_decay=0)
    optim = torch.optim.RMSprop(pose_net.parameters(), **rms_settings)

    # optionally resume from a checkpoint
    resuming = opt.load_prefix_pose != ''
    if resuming:
        prefix_path = os.path.join(exp_dir, opt.load_prefix_pose)
        checkpoint.save_prefix = prefix_path
        # The load prefix is the same path minus its final character.
        checkpoint.load_prefix = prefix_path[:-1]
        checkpoint.load_checkpoint(pose_net, optim, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print('save prefix: ', checkpoint.save_prefix)

    # load data
    def build_loader(is_train):
        # The training split is shuffled, the validation split is not.
        split = MPII('dataset/mpii-hr-lsp-normalizer.json',
                     '/bigdata1/zt53/data',
                     is_train=is_train)
        return torch.utils.data.DataLoader(split,
                                           batch_size=opt.bs,
                                           shuffle=is_train,
                                           num_workers=opt.nThreads,
                                           pin_memory=True)

    train_loader = build_loader(True)
    val_loader = build_loader(False)

    print(type(optim), optim.param_groups[0]['lr'])
    # Index subset handed to train/validate: 0..15 without 6, 7, 8, 9, 12, 13.
    kept_idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]

    if not opt.is_train:
        # Validation-only run at the last epoch stored in the history.
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        last_epoch = train_history.epoch[-1]['epoch']
        _, _, predictions = validate(val_loader, pose_net, last_epoch,
                                     visualizer, kept_idx, class_count)
        checkpoint.save_preds(predictions)
        return

    # training and validation
    start_epoch = train_history.epoch[-1]['epoch'] + 1 if resuming else 0
    for epoch in range(start_epoch, opt.nEpochs):
        adjust_lr(opt, optim, epoch)
        # train for one epoch
        train_loss, train_pckh = train(train_loader, pose_net, optim, epoch,
                                       visualizer, kept_idx, opt)
        # evaluate on validation set
        val_loss, val_pckh, predictions = validate(val_loader, pose_net,
                                                   epoch, visualizer,
                                                   kept_idx, class_count)

        # update training history
        epoch_entry = OrderedDict()
        epoch_entry['epoch'] = epoch
        lr_entry = OrderedDict()
        lr_entry['lr'] = optim.param_groups[0]['lr']
        loss_entry = OrderedDict()
        loss_entry['train_loss'] = train_loss
        loss_entry['val_loss'] = val_loss
        pckh_entry = OrderedDict()
        pckh_entry['train_pckh'] = train_pckh
        pckh_entry['val_pckh'] = val_pckh
        train_history.update(epoch_entry, lr_entry, loss_entry, pckh_entry)

        checkpoint.save_checkpoint(pose_net, optim, train_history,
                                   predictions)
        visualizer.plot_train_history(train_history)
def main():
    """Train the GRM model and checkpoint it after every epoch.

    Parses the command-line arguments, builds the data loaders and the ``GRM``
    network, optionally loads fine-tuned weights into ``model.fg`` and a
    checkpoint through ``Checkpoint``, then for every remaining epoch runs
    ``train_eval`` and ``validate_eval``, logs the scores to the summary
    writer and stdout, and saves a checkpoint.

    Side effects: rebinds the module-level names ``writer`` and (when
    ``args.checkpoint_dir`` is set) ``cp_recorder``.
    """
    global writer
    global cp_recorder

    args = parser.parse_args()
    print('\n====> Input Arguments')
    print(args)
    # Tensorboard writer.
    writer = SummaryWriter(log_dir=args.result_path)

    # Create dataloader.
    # Single-argument print(...) calls give the same output as the former
    # Python 2 print statements and are valid syntax on Python 3.
    print('\n====> Creating dataloader...')
    train_loader = get_train_loader(args)
    test_loader = get_test_loader(args)

    # Load GRM network.
    print('====> Loading the GRM network...')
    model = GRM(num_classes=args.num_classes,
                adjacency_matrix=args.adjacency_matrix)

    # Load First-Glance network.
    print('====> Loading the finetune First Glance model...')
    if args.fg_finetune and os.path.isfile(args.fg_finetune):
        model.fg.load_state_dict(torch.load(args.fg_finetune))
    else:
        print("No find '{}'".format(args.fg_finetune))

    # Load checkpoint and weight of network.
    # NOTE(review): cp_recorder is only assigned here when checkpoint_dir is
    # set, yet the training loop below always reads it; unless the module
    # defines it elsewhere this fails without a checkpoint dir -- confirm.
    if args.checkpoint_dir:
        cp_recorder = Checkpoint(args.checkpoint_dir, args.checkpoint_name)
        cp_recorder.load_checkpoint(model)

    model.fg = torch.nn.DataParallel(model.fg)
    model.full_im_net = torch.nn.DataParallel(model.full_im_net)
    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True
    optimizer_cls = torch.optim.SGD(model.classifier.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.wd,
                                    momentum=args.momentum)
    optimizer_ggnn = torch.optim.Adam(model.ggnn.parameters(),
                                      lr=0.00001,
                                      weight_decay=args.wd)

    # Train GRM model, starting from the epoch stored in the checkpoint.
    print('====> Training...')
    for epoch in range(cp_recorder.contextual['b_epoch'], args.epoch):
        _, _, prec_tri, rec_tri, ap_tri = train_eval(train_loader, test_loader,
                                                     model, criterion,
                                                     optimizer_cls,
                                                     optimizer_ggnn, args,
                                                     epoch)
        top1_avg_val, loss_avg_val, prec_val, rec_val, ap_val = validate_eval(
            test_loader, model, criterion, args, epoch)

        # Mean AP with NaN entries counted as 0; computed once per epoch and
        # reused for logging, printing and the checkpoint record.
        map_tri = np.nan_to_num(ap_tri).mean()
        map_val = np.nan_to_num(ap_val).mean()
        # Per-batch step index corresponding to the end of this epoch.
        batches_seen = (epoch + 1) * len(train_loader)

        # Print result.
        writer.add_scalars('mAP (per epoch)', {'train': map_tri}, epoch)
        writer.add_scalars('mAP (per epoch)', {'valid': map_val}, epoch)
        print('\n====> Scores')
        print(
            '[Epoch {0}]:\n'
            '  Train:\n'
            '    Prec@1 {1}\n'
            '    Recall {2}\n'
            '    AP {3}\n'
            '    mAP {4:.3f}\n'
            '  Valid:\n'
            '    Prec@1 {5}\n'
            '    Recall {6}\n'
            '    AP {7}\n'
            '    mAP {8:.3f}\n'.format(epoch, prec_tri, rec_tri, ap_tri,
                                       map_tri, prec_val, rec_val, ap_val,
                                       map_val))

        # Record.
        writer.add_scalars('Loss (per batch)', {'valid': loss_avg_val},
                           batches_seen)
        writer.add_scalars('Prec@1 (per batch)', {'valid': top1_avg_val},
                           batches_seen)
        writer.add_scalars('mAP (per batch)', {'valid': map_val},
                           batches_seen)

        # Save checkpoint; 'b_epoch' is the epoch to resume from.
        cp_recorder.record_contextual({
            'b_epoch': epoch + 1,
            'b_batch': -1,
            'prec': top1_avg_val,
            'loss': loss_avg_val,
            'class_prec': prec_val,
            'class_recall': rec_val,
            'class_ap': ap_val,
            'mAP': map_val
        })
        cp_recorder.save_checkpoint(model)