Example #1
    log_file = os.path.join(log_dir, opt.version + '.txt')
    with open(log_file, 'a') as f:
        f.write(str(opt) + '\n')
        f.flush()

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
    cudnn.benchmark = True

    data = Data()
    model = build_model(opt, data.num_classes)
    optimizer = make_optimizer(opt, model)
    loss = make_loss(opt, data.num_classes)

    # WARMUP_FACTOR: 0.01
    # WARMUP_ITERS: 10
    scheduler = WarmupMultiStepLR(optimizer, opt.steps, 0.1, 0.01, 10,
                                  "linear")
    main = Main(opt, model, data, optimizer, scheduler, loss)

    if opt.mode == 'train':

        # total number of training epochs
        epoch = 200
        start_epoch = 1

        # resume training from a checkpoint
        if opt.resume:
            ckpt = torch.load(opt.resume)
            start_epoch = ckpt['epoch']
            logger.info('resume from epoch: {}'.format(start_epoch))
            model.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
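
None of these examples include the definition of WarmupMultiStepLR itself. For reference, below is a minimal sketch compatible with the positional call above (gamma=0.1, warmup_factor=0.01, warmup_iters=10, warmup_method="linear"), modeled on the widely copied maskrcnn-benchmark-style scheduler; treat it as an approximation of whatever each project above actually ships, not the exact implementation.

from bisect import bisect_right

import torch


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step decay with an initial warmup ramp (maskrcnn-benchmark style)."""

    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method="linear", last_epoch=-1):
        if list(milestones) != sorted(milestones):
            raise ValueError("Milestones should be increasing, got {}".format(milestones))
        self.milestones = list(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # During warmup, scale the base lr from warmup_factor up to 1.0;
        # afterwards, apply gamma once per milestone already passed.
        warmup_factor = 1.0
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr * warmup_factor *
            self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

With the call in Example #1 and this sketch, the learning rate would ramp linearly from 1% of the base rate to the full rate over the first 10 steps, then be multiplied by 0.1 at each step index listed in opt.steps.
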
Example #2
def main(args):
    if args.output_dir:
        makedir(args.output_dir)
    print(args)
    device = torch.device(args.device)

    transform_train = VideoClassificationPresetTrain((128, 171), (112, 112))
    dataset_train = VideoDatasetCustom(args.train_dir,
                                       "annotations.txt",
                                       transform=transform_train)

    transform_eval = VideoClassificationPresetEval((128, 171), (112, 112))
    dataset_eval = VideoDatasetCustom(args.val_dir,
                                      "annotations.txt",
                                      transform=transform_eval)

    train_sampler = RandomClipSampler(dataset_train.clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_eval.clips, args.clips_per_video)

    data_loader = torch.utils.data.DataLoader(dataset_train,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_eval = torch.utils.data.DataLoader(dataset_eval,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    model = torchvision.models.video.__dict__[args.model](
        pretrained=args.pretrained)
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    print("Start training")
    writer = SummaryWriter('runs/vc_experiment_1')
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, writer)
        evaluate(model, criterion, data_loader_eval, device=device)

        if args.output_dir:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            print("Saving checkpoint to {}".format(
                os.path.join(args.output_dir, 'checkpoint.pth')))
            torch.save(checkpoint,
                       os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
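
In this example the warmup length and the milestones are expressed in optimizer iterations rather than epochs (the scheduler is presumably stepped once per batch inside train_one_epoch), hence the multiplication by len(data_loader). A quick worked illustration; the numbers below are hypothetical, not taken from the example:

# hypothetical values, for illustration only
lr_warmup_epochs = 10
batches_per_epoch = 245                                   # len(data_loader)
warmup_iters = lr_warmup_epochs * batches_per_epoch       # 2450 optimizer steps
lr_milestones_epochs = [20, 30]
lr_milestones = [batches_per_epoch * m for m in lr_milestones_epochs]  # [4900, 7350]
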
Example #3
        for state in optimizer_model_rpn.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        for state in optimizer_classifier.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()


model_rpn_cuda = f'cuda:{model_rpn.device_ids[0]}'
model_classifier_cuda = f'cuda:{model_classifier.device_ids[0]}'


scheduler_rpn = WarmupMultiStepLR(optimizer_model_rpn, milestones=[40, 70], gamma=args.gamma, warmup_factor=0.01, warmup_iters=10)
scheduler_class = WarmupMultiStepLR(optimizer_classifier, milestones=[40, 70], gamma=args.gamma, warmup_factor=0.01, warmup_iters=10)

all_possible_anchor_boxes = default_anchors(out_h=out_h, out_w=out_w, anchor_sizes=anchor_sizes, anchor_ratios=anchor_ratios, downscale=16)
all_possible_anchor_boxes_tensor = torch.tensor(all_possible_anchor_boxes).to(device=device)


def train(epoch):
    print("\n\nTraining epoch {}\n\n".format(epoch))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    
    regr_rpn_loss = 0
    class_rpn_loss = 0
    total_rpn_loss = 0
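
The indented block at the top of this fragment is the tail of a checkpoint-resume routine: optimizer.load_state_dict() restores momentum buffers on the CPU, so they are moved back to the GPU by hand before training continues. A minimal sketch of the surrounding pattern it implies; the checkpoint keys and argument names here are assumptions, not from the original code:

if args.resume:
    ckpt = torch.load(args.resume, map_location='cpu')
    model_rpn.load_state_dict(ckpt['model_rpn'])
    model_classifier.load_state_dict(ckpt['model_classifier'])
    optimizer_model_rpn.load_state_dict(ckpt['optimizer_rpn'])
    optimizer_classifier.load_state_dict(ckpt['optimizer_classifier'])

    # Optimizer state tensors come back on the CPU; push them to the GPU
    # so later optimizer.step() calls do not mix devices.
    for optimizer in (optimizer_model_rpn, optimizer_classifier):
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
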
Example #4
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    #print("Initializing dataset {}".format(args.dataset))
    # dataset = data_manager.init_dataset(name=args.dataset)

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_train_p = T.Compose([
        T.Random2DTranslation(256, 128),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test_p = T.Compose([
        T.Resize((256, 128)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_file = 'data/cuhk_train.pkl'
    test_file = 'data/cuhk_test.pkl'
    gallery_file = 'data/cuhk_gallery.pkl'
    data_root = args.data_root
    dataset_train = CUHKGroup(train_file, data_root, True, transform_train,
                              transform_train_p)
    dataset_test = CUHKGroup(test_file, data_root, False, transform_test,
                             transform_test_p)
    dataset_query = CUHKGroup(test_file, data_root, False, transform_test,
                              transform_test_p)
    dataset_gallery = CUHKGroup(gallery_file, data_root, False, transform_test,
                                transform_test_p)

    pin_memory = True if use_gpu else False

    if args.xent_only:
        trainloader = DataLoader(
            dataset_train,
            batch_size=args.train_batch,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=pin_memory,
            drop_last=True,
        )
    else:
        trainloader = DataLoader(
            dataset_train,
            sampler=RandomIdentitySampler(dataset_train,
                                          num_instances=args.num_instances),
            batch_size=args.train_batch,
            num_workers=args.workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    queryloader = DataLoader(
        dataset_test,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    querygalleryloader = DataLoader(
        dataset_query,
        batch_size=args.gallery_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    galleryloader = DataLoader(
        dataset_gallery,
        batch_size=args.gallery_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    print("Initializing model: {}".format(args.arch))
    if args.xent_only:
        # model = models.init_model(name=args.arch, num_classes=dataset_train.num_train_gids, loss={'xent'})
        model = models.init_model(name=args.arch,
                                  num_classes=dataset_train.num_train_gids,
                                  loss={'xent'})
    else:
        # model = models.init_model(name=args.arch, num_classes=dataset_train.num_train_gids, loss={'xent', 'htri'})
        model = models.init_model(
            name=args.arch,
            num_classes=dataset_train.num_train_gids,
            num_person_classes=dataset_train.num_train_pids,
            loss={'xent', 'htri'})

    #criterion_xent = CrossEntropyLabelSmooth(num_classes=dataset_train.num_train_gids, use_gpu=use_gpu)
    #criterion_xent_person = CrossEntropyLabelSmooth(num_classes=dataset_train.num_train_pids, use_gpu=use_gpu)

    if os.path.exists(args.pretrained_model):
        print("Loading checkpoint from '{}'".format(args.pretrained_model))
        checkpoint = torch.load(args.pretrained_model)
        model_dict = model.state_dict()
        pretrain_dict = checkpoint['state_dict']
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items() if k in model_dict
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)

    criterion_xent = nn.CrossEntropyLoss(ignore_index=-1)
    criterion_xent_person = nn.CrossEntropyLoss(ignore_index=-1)
    criterion_htri = TripletLoss(margin=args.margin)
    criterion_pair = ContrastiveLoss(margin=args.margin)
    criterion_htri_filter = TripletLossFilter(margin=args.margin)
    criterion_permutation = PermutationLoss()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.stepsize > 0:
        if args.warmup:
            scheduler = WarmupMultiStepLR(optimizer, [200, 400, 600])
        else:
            scheduler = lr_scheduler.StepLR(optimizer,
                                            step_size=args.stepsize,
                                            gamma=args.gamma)
    start_epoch = args.start_epoch

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test_gcn_person_batch(model, queryloader, querygalleryloader,
                              galleryloader, args.pool, use_gpu)
        #test_gcn_batch(model, queryloader, querygalleryloader, galleryloader, args.pool, use_gpu)
        #test_gcn(model, queryloader, galleryloader, args.pool, use_gpu)
        #test(model, queryloader, galleryloader, args.pool, use_gpu)
        return

    start_time = time.time()
    best_rank1 = -np.inf
    for epoch in range(start_epoch, args.max_epoch):
        #print("==> Epoch {}/{}  lr:{}".format(epoch + 1, args.max_epoch, scheduler.get_lr()[0]))

        train_gcn(model, criterion_xent, criterion_xent_person, criterion_pair,
                  criterion_htri_filter, criterion_htri, criterion_permutation,
                  optimizer, trainloader, use_gpu)
        #train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)

        if args.stepsize > 0: scheduler.step()

        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (
                epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test_gcn_person_batch(model, queryloader,
                                          querygalleryloader, galleryloader,
                                          args.pool, use_gpu)
            #rank1 = test_gcn(model, queryloader, galleryloader, args.pool, use_gpu=False)
            #rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1
            if is_best: best_rank1 = rank1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
Example #5
    loss_function = CrossEntropyLabelSmooth(num_classes=35)

    # optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)   #learning rate decay
    # iter_per_epoch = len(tiger_training_loader)
    # warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    # original lr policy
    optimizer = make_optimizer(args, net)
    warmup_scheduler = WarmupMultiStepLR(
        optimizer,
        settings.MILESTONES,
        gamma=0.5,  #0.1, 0.5
        warmup_factor=1.0 / 3,
        warmup_iters=0,
        warmup_method="linear",
        last_epoch=-1,
    )

    #cycle lr
    # optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # warmup_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer,
    #     T_max = 10,
    #     eta_min = 0.000001
    # )

    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                   settings.TIME_NOW)
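
Note that warmup_iters=0 above effectively disables warmup, which is easy to miss. Dry-running the scheduler against a throwaway optimizer and printing the effective learning rate is a cheap sanity check before training; the sketch below assumes the same WarmupMultiStepLR signature used throughout these examples, with placeholder milestones and gamma:

import torch

# Throwaway parameter and optimizer, used only to inspect the schedule.
probe_params = [torch.nn.Parameter(torch.zeros(1))]
probe_opt = torch.optim.SGD(probe_params, lr=0.1)
probe_sched = WarmupMultiStepLR(probe_opt, [60, 120], gamma=0.5,
                                warmup_factor=1.0 / 3, warmup_iters=5,
                                warmup_method="linear")
for step in range(130):
    probe_opt.step()    # keeps PyTorch's "scheduler before optimizer step" warning quiet
    probe_sched.step()
    if step < 6 or step in (59, 60, 119, 120):
        print(step, probe_opt.param_groups[0]['lr'])
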
Example #6
    def __init__(self):
        self.opts = options.parse()
        self.save_dir = os.path.join(self.opts.save_dir, self.opts.model_name)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # set a logger
        logging.basicConfig(filename=os.path.join(self.save_dir,
                                                  "log-train.log"),
                            format='%(asctime)s %(message)s',
                            datefmt='%m/%d/%Y %p %I:%M:%S',
                            level=logging.INFO)
        logging.getLogger().setLevel(logging.INFO)
        self.logger = logging.getLogger("trainLogger")
        self.logger.addHandler(logging.StreamHandler(sys.stdout))

        # set device for training/validation
        use_cuda = torch.cuda.is_available()  # check if GPU exists
        self.device = torch.device(
            "cuda" if use_cuda else "cpu")  # use CPU or GPU

        # train, val, test datasets
        data_train = AirTypingDataset(self.opts, self.opts.data_path_train)
        data_val = AirTypingDataset(self.opts, self.opts.data_path_val)
        data_test = AirTypingDataset(self.opts, self.opts.data_path_test)
        self.data_loader_train = DataLoader(data_train,
                                            batch_size=self.opts.batch_size,
                                            shuffle=True,
                                            num_workers=self.opts.num_workers,
                                            collate_fn=pad_collate,
                                            drop_last=True)
        self.data_loader_val = DataLoader(data_val,
                                          batch_size=self.opts.batch_size,
                                          shuffle=False,
                                          num_workers=self.opts.num_workers,
                                          collate_fn=pad_collate,
                                          drop_last=True)
        self.data_loader_test = DataLoader(data_test,
                                           batch_size=self.opts.batch_size,
                                           shuffle=False,
                                           num_workers=self.opts.num_workers,
                                           collate_fn=pad_collate,
                                           drop_last=True)

        # create a model
        self.model = GestureTranslator(self.opts).to(self.device)

        # parallelize the model to multiple GPUs
        if torch.cuda.device_count() > 1:
            self.logger.info("We're Using {} GPUs!".format(
                torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
            self.model_without_dp = self.model.module
        elif torch.cuda.device_count() == 1:
            self.logger.info("We're Using {} GPU!".format(
                torch.cuda.device_count()))
            self.model_without_dp = self.model

        # define an optimizer
        self.params_to_train = list(self.model_without_dp.parameters())
        if self.opts.optimizer_type == 'rmsprop':
            self.optimizer = torch.optim.RMSprop(self.params_to_train,
                                                 lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'sgd':
            self.optimizer = torch.optim.SGD(self.params_to_train,
                                             lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'adam':
            self.optimizer = torch.optim.Adam(self.params_to_train,
                                              lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'adamW':
            self.optimizer = torch.optim.AdamW(self.params_to_train,
                                               lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'lamb':
            self.optimizer = Lamb(self.params_to_train,
                                  lr=self.opts.learning_rate)

        # define a scheduler
        iter_size = len(self.data_loader_train)
        scheduler_step_size = self.opts.scheduler_step_size * iter_size
        if self.opts.scheduler_type == 'warmup':
            self.scheduler = WarmupMultiStepLR(self.optimizer, [
                scheduler_step_size * (i + 1)
                for i in range(self.opts.num_epochs)
            ],
                                               gamma=self.opts.scheduler_gamma,
                                               warmup_iters=500)
        elif self.opts.scheduler_type == 'steplr':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=scheduler_step_size,
                gamma=self.opts.scheduler_gamma)
        else:
            self.scheduler = None

        # load pretrained weights if available
        if self.opts.load_dir:
            self.load_model()

        # define a loss
        self.ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)

        # set train variables and save options
        self.epoch = 0
        self.step = 0
        self.start_step = 0
        self.start_time = time.time()
        self.num_total_steps = iter_size * self.opts.num_epochs
        self.best_val_loss = float('Inf')
        self.best_val_cer = float('Inf')

        self.save_opts()
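
The warmup branch above places a learning-rate decay every scheduler_step_size epochs, expressed in optimizer steps because the scheduler is stepped once per batch (see run_one_epoch in the next example), after a 500-iteration warmup. A tiny worked illustration with assumed numbers:

# hypothetical values, for illustration only
iter_size = 800                           # len(self.data_loader_train)
num_epochs = 3                            # self.opts.num_epochs
scheduler_step_size = 5 * iter_size       # self.opts.scheduler_step_size = 5 epochs
milestones = [scheduler_step_size * (i + 1) for i in range(num_epochs)]
# -> [4000, 8000, 12000] optimizer steps
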
Example #7
class Trainer:
    def __init__(self):
        self.opts = options.parse()
        self.save_dir = os.path.join(self.opts.save_dir, self.opts.model_name)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # set a logger
        logging.basicConfig(filename=os.path.join(self.save_dir,
                                                  "log-train.log"),
                            format='%(asctime)s %(message)s',
                            datefmt='%m/%d/%Y %p %I:%M:%S',
                            level=logging.INFO)
        logging.getLogger().setLevel(logging.INFO)
        self.logger = logging.getLogger("trainLogger")
        self.logger.addHandler(logging.StreamHandler(sys.stdout))

        # set device for training/validation
        use_cuda = torch.cuda.is_available()  # check if GPU exists
        self.device = torch.device(
            "cuda" if use_cuda else "cpu")  # use CPU or GPU

        # train, val, test datasets
        data_train = AirTypingDataset(self.opts, self.opts.data_path_train)
        data_val = AirTypingDataset(self.opts, self.opts.data_path_val)
        data_test = AirTypingDataset(self.opts, self.opts.data_path_test)
        self.data_loader_train = DataLoader(data_train,
                                            batch_size=self.opts.batch_size,
                                            shuffle=True,
                                            num_workers=self.opts.num_workers,
                                            collate_fn=pad_collate,
                                            drop_last=True)
        self.data_loader_val = DataLoader(data_val,
                                          batch_size=self.opts.batch_size,
                                          shuffle=False,
                                          num_workers=self.opts.num_workers,
                                          collate_fn=pad_collate,
                                          drop_last=True)
        self.data_loader_test = DataLoader(data_test,
                                           batch_size=self.opts.batch_size,
                                           shuffle=False,
                                           num_workers=self.opts.num_workers,
                                           collate_fn=pad_collate,
                                           drop_last=True)

        # create a model
        self.model = GestureTranslator(self.opts).to(self.device)

        # parallelize the model to multiple GPUs
        if torch.cuda.device_count() > 1:
            self.logger.info("We're Using {} GPUs!".format(
                torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
            self.model_without_dp = self.model.module
        elif torch.cuda.device_count() == 1:
            self.logger.info("We're Using {} GPU!".format(
                torch.cuda.device_count()))
            self.model_without_dp = self.model

        # define an optimizer
        self.params_to_train = list(self.model_without_dp.parameters())
        if self.opts.optimizer_type == 'rmsprop':
            self.optimizer = torch.optim.RMSprop(self.params_to_train,
                                                 lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'sgd':
            self.optimizer = torch.optim.SGD(self.params_to_train,
                                             lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'adam':
            self.optimizer = torch.optim.Adam(self.params_to_train,
                                              lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'adamW':
            self.optimizer = torch.optim.AdamW(self.params_to_train,
                                               lr=self.opts.learning_rate)
        elif self.opts.optimizer_type == 'lamb':
            self.optimizer = Lamb(self.params_to_train,
                                  lr=self.opts.learning_rate)

        # define a scheduler
        iter_size = len(self.data_loader_train)
        scheduler_step_size = self.opts.scheduler_step_size * iter_size
        if self.opts.scheduler_type == 'warmup':
            self.scheduler = WarmupMultiStepLR(self.optimizer, [
                scheduler_step_size * (i + 1)
                for i in range(self.opts.num_epochs)
            ],
                                               gamma=self.opts.scheduler_gamma,
                                               warmup_iters=500)
        elif self.opts.scheduler_type == 'steplr':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=scheduler_step_size,
                gamma=self.opts.scheduler_gamma)
        else:
            self.scheduler = None

        # load pretrained weights if available
        if self.opts.load_dir:
            self.load_model()

        # define a loss
        self.ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)

        # set train variables and save options
        self.epoch = 0
        self.step = 0
        self.start_step = 0
        self.start_time = time.time()
        self.num_total_steps = iter_size * self.opts.num_epochs
        self.best_val_loss = float('Inf')
        self.best_val_cer = float('Inf')

        self.save_opts()

    def train(self):
        for self.epoch in range(self.opts.num_epochs):
            print("epoch: ", self.epoch)
            self.model.train()
            self.run_one_epoch()
            is_best = self.validate(self.data_loader_val, False)
            if is_best or (self.opts.save_frequency > 0 and
                           (self.epoch + 1) % self.opts.save_frequency == 0):
                self.save_model(is_best)

    def run_one_epoch(self):
        losses = []
        for batch_idx, (xx_pad, yy_pad, x_lens,
                        y_lens) in enumerate(self.data_loader_train):
            time_before_process = time.time()
            # distribute data to device
            xx_pad, yy_pad = xx_pad.to(self.device), yy_pad.to(self.device)
            x_lens, y_lens = x_lens.to(self.device), y_lens.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(xx_pad, x_lens)
            output = output.permute(1, 0, 2).log_softmax(2)
            loss = self.ctc_loss(output, yy_pad, x_lens, y_lens)
            y_pred = torch.max(output, 2)[1]
            losses.append(loss.item())
            loss.backward()
            self.optimizer.step()

            if self.scheduler:
                self.scheduler.step()

            duration = abs(time_before_process - time.time())
            # log information
            if (batch_idx + 1) % self.opts.log_interval == 0:
                self.log_time(duration, batch_idx, loss.item())
                self.logger.info('\tGround Truth: {}'.format(
                    self.data_loader_train.dataset.converter.decode(
                        yy_pad[0, :y_lens[0]], y_lens[:1])))
                self.logger.info('\tModel Output: {}'.format(
                    self.data_loader_train.dataset.converter.decode(
                        y_pred[:x_lens[0], 0], x_lens[:1])))
                if self.scheduler:
                    self.logger.info("\tCurrent LR: {:.6f}".format(
                        self.scheduler.get_lr()[0]))
            self.step += 1
        return losses

    def validate(self, data_loader, load=False):
        self.logger.info('---------------Validation------------------')
        if load:
            self.load_model()
        self.model.eval()
        losses = []
        with torch.no_grad():
            total_err, total_len = 0, 0
            for batch_idx, (xx_pad, yy_pad, x_lens,
                            y_lens) in enumerate(data_loader):
                # distribute data to device
                xx_pad, yy_pad = xx_pad.to(self.device), yy_pad.to(self.device)
                x_lens, y_lens = x_lens.to(self.device), y_lens.to(self.device)

                output = self.model(xx_pad, x_lens)
                output = output.permute(1, 0, 2).log_softmax(2)
                loss = self.ctc_loss(output, yy_pad, x_lens, y_lens)
                y_pred = torch.max(output, 2)[1]
                losses.append(loss.item())

                # to compute accuracy
                # TODO: need to work for batch_size > 1
                gt = data_loader.dataset.converter.decode(
                    yy_pad[0, :y_lens[0]], y_lens[:1])
                pred = data_loader.dataset.converter.decode(
                    y_pred[:x_lens[0], 0], x_lens[:1])
                err, length = cer(gt, pred)
                if err > length:
                    err = length
                total_err += err
                total_len += length

                # log intermediate results
                if (batch_idx + 1) % self.opts.log_interval == 0:
                    self.logger.info('\tGround Truth: {}'.format(gt))
                    self.logger.info('\tModel Output: {}'.format(pred))

            cur_val_loss = sum(losses) / len(losses)
            cur_val_cer = total_err / total_len

            self.logger.info(
                '\tcurrent_validation_loss: {}'.format(cur_val_loss))
            self.logger.info(
                '\tcurrent_validation_cer: {}'.format(cur_val_cer))

            if self.best_val_cer > cur_val_cer:
                self.best_val_cer = cur_val_cer
                return True
            return False

    def log_time(self, duration, batch_idx, loss):
        samples_per_sec = self.opts.batch_size / duration
        time_sofar = time.time() - self.start_time
        training_time_left = (
            (self.num_total_steps - self.step) /
            (self.step - self.start_step)) * time_sofar if self.step > 0 else 0
        print_string = "epoch {:>3} | batch [{:>4}/{:>4}] | examples/s: {:5.1f}" + \
            " | loss: {:.5f} | time elapsed: {} | time left: {}"
        self.logger.info(
            print_string.format(self.epoch, batch_idx * self.opts.batch_size,
                                len(self.data_loader_train.dataset),
                                samples_per_sec, loss,
                                sec_to_hm_str(time_sofar),
                                sec_to_hm_str(training_time_left)))

    def save_opts(self):
        to_save = self.opts.__dict__.copy()
        with open(os.path.join(self.save_dir, 'opts.json'), 'w') as f:
            json.dump(to_save, f, indent=2)

    def save_model(self, is_best):
        save_folder = os.path.join(
            self.save_dir, "models",
            "weights_{}{}".format(self.epoch, "_best" if is_best else ""))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        save_path = os.path.join(save_folder, "{}.pth".format("model"))
        to_save = self.model_without_dp.state_dict()
        torch.save(to_save, save_path)

        save_path = os.path.join(save_folder, "{}.pth".format("optimizer"))
        torch.save(self.optimizer.state_dict(), save_path)

        if self.scheduler:
            save_path = os.path.join(save_folder, "{}.pth".format("scheduler"))
            torch.save(self.scheduler.state_dict(), save_path)

    def load_model(self):
        self.opts.load_dir = os.path.expanduser(self.opts.load_dir)

        assert os.path.isdir(self.opts.load_dir), \
            "Cannot find directory {}".format(self.opts.load_dir)
        print("loading model from directory {}".format(self.opts.load_dir))

        # loading model state
        path = os.path.join(self.opts.load_dir, "model.pth")
        pretrained_dict = torch.load(path)
        self.model_without_dp.load_state_dict(pretrained_dict)

        # loading optimizer state
        optimizer_load_path = os.path.join(self.opts.load_dir, "optimizer.pth")
        if os.path.isfile(optimizer_load_path):
            print("Loading Optimizer weights")
            optimizer_dict = torch.load(optimizer_load_path)
            self.optimizer.load_state_dict(optimizer_dict)
        else:
            print(
                "Cannot find Optimizer weights, so Optimizer keeps its default state"
            )

        # loading scheduler state
        if self.scheduler:
            scheduler_load_path = os.path.join(self.opts.load_dir,
                                               "scheduler.pth")
            if os.path.isfile(scheduler_load_path):
                print("Loading Scheduler weights")
                scheduler_dict = torch.load(scheduler_load_path)
                self.scheduler.load_state_dict(scheduler_dict)
            else:
                print(
                    "Cannot find Scheduler weights, so Scheduler starts from its initial state"
                )
Example #8
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    if args.xent_only:
        trainloader = DataLoader(
            VideoDataset(dataset.train,
                         seq_len=args.seq_len,
                         sample='random',
                         transform=transform_train),
            batch_size=args.train_batch,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=pin_memory,
            drop_last=True,
        )
    else:
        trainloader = DataLoader(
            VideoDataset(dataset.train,
                         seq_len=args.seq_len,
                         sample='random',
                         transform=transform_train),
            sampler=RandomIdentitySampler(dataset.train,
                                          num_instances=args.num_instances),
            batch_size=args.train_batch,
            num_workers=args.workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    if args.arch == 'resnet503d':
        model = resnet3d.resnet50(num_classes=dataset.num_train_pids,
                                  sample_width=args.width,
                                  sample_height=args.height,
                                  sample_duration=args.seq_len)
        if not os.path.exists(args.pretrained_model):
            raise IOError("Can't find pretrained model: {}".format(
                args.pretrained_model))
        print("Loading checkpoint from '{}'".format(args.pretrained_model))
        checkpoint = torch.load(args.pretrained_model)
        state_dict = {}
        for key in checkpoint['state_dict']:
            if 'fc' in key: continue
            state_dict[key.partition("module.")
                       [2]] = checkpoint['state_dict'][key]
        model.load_state_dict(state_dict, strict=False)
    else:
        #model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, dropout=args.dropout, nhid=args.nhid, nheads=args.nheads, concat=args.concat, loss={'xent', 'htri'})
        model = models.init_model(name=args.arch,
                                  pool_size=8,
                                  input_shape=2048,
                                  n_classes=dataset.num_train_pids,
                                  loss={'xent', 'htri'})
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    if os.path.exists(args.pretrained_model):
        print("Loading checkpoint from '{}'".format(args.pretrained_model))
        checkpoint = torch.load(args.pretrained_model)
        model_dict = model.state_dict()
        pretrain_dict = checkpoint['state_dict']
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items() if k in model_dict
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)

    criterion_xent = CrossEntropyLabelSmooth(
        num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    criterion_htri = TripletLoss(margin=args.margin)

    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    if args.stepsize > 0:
        if args.warmup:
            scheduler = WarmupMultiStepLR(optimizer, [200, 400, 600])
        else:
            scheduler = lr_scheduler.StepLR(optimizer,
                                            step_size=args.stepsize,
                                            gamma=args.gamma)
    start_epoch = args.start_epoch

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch'] + 1
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, args.pool, use_gpu)
        return

    start_time = time.time()
    best_rank1 = -np.inf
    if args.arch == 'resnet503d':
        torch.backends.cudnn.benchmark = False
    '''
    adj1 = build_adj_full_full(4, args.part1)
    adj2 = build_adj_full_full(4, args.part2)
    adj3 = build_adj_full_full(4, args.part3)
    if use_gpu:
        adj1 = adj1.cuda()
        adj2 = adj2.cuda()
        adj3 = adj3.cuda()
    adj1 = Variable(adj1)
    adj2 = Variable(adj2)
    adj3 = Variable(adj3)
    '''

    torch.cuda.empty_cache()
    for epoch in range(start_epoch, args.max_epoch):
        print("==> Epoch {}/{}  lr:{}".format(epoch + 1, args.max_epoch,
                                              scheduler.get_lr()[0]))

        #train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, adj1, adj2, adj3)
        train(model, criterion_xent, criterion_htri, optimizer, trainloader,
              use_gpu)

        if args.stepsize > 0: scheduler.step()

        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (
                epoch + 1) == args.max_epoch:
            print("==> Test")
            #rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu, adj1, adj2, adj3)
            rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu)
            is_best = rank1 > best_rank1
            if is_best: best_rank1 = rank1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
        torch.cuda.empty_cache()

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))