示例#1
0
def main():
    """Build the Tacotron model, optionally resume from a checkpoint, and train.

    The checkpoint directory is ``<args.logdir>/<model.name>``:

    * if it does not exist it is created;
    * if it exists without a ``ckpt.csv`` index it is wiped and recreated;
    * otherwise the checkpoint with the lowest recorded loss is restored
      (model, optimizer, scheduler and ``args.global_step``).

    Returns:
        None
    """
    model = Tacotron().to(DEVICE)
    print('Model {} is working...'.format(model.name))
    print('{} threads are used...'.format(torch.get_num_threads()))
    ckpt_dir = os.path.join(args.logdir, model.name)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # gamma is applied ten times per lr_decay_step: 0.933 ** 10 ~= 0.5.
    scheduler = StepLR(optimizer,
                       step_size=args.lr_decay_step // 10,
                       gamma=0.933)

    ckpt_index = os.path.join(ckpt_dir, 'ckpt.csv')
    train_log_dir = os.path.join(ckpt_dir, 'A', 'train')

    if not os.path.exists(ckpt_dir):
        os.makedirs(train_log_dir)
    elif not os.path.exists(ckpt_index):
        # A directory without an index cannot be resumed from: start clean.
        shutil.rmtree(ckpt_dir)
        os.makedirs(train_log_dir)
    else:
        print('Already exists. Retrain the model.')
        ckpt = pd.read_csv(ckpt_index, sep=',', header=None)
        ckpt.columns = ['models', 'loss']
        ckpt = ckpt.sort_values(by='loss', ascending=True)
        # Positional access is required here: after sorting, label 0 still
        # refers to the first line of the CSV, not to the lowest-loss row.
        best_ckpt_name = ckpt.models.iloc[0]
        state = torch.load(os.path.join(ckpt_dir, best_ckpt_name))
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])

    # model = torch.nn.DataParallel(model, device_ids=list(range(args.no_gpu))).to(DEVICE)

    dataset = SpeechDataset(args.data_path,
                            args.meta_train,
                            model.name,
                            mem_mode=args.mem_mode)
    validset = SpeechDataset(args.data_path,
                             args.meta_eval,
                             model.name,
                             mem_mode=args.mem_mode)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             drop_last=True,
                             pin_memory=True)
    valid_loader = DataLoader(dataset=validset,
                              batch_size=args.test_batch,
                              shuffle=False,
                              collate_fn=collate_fn,
                              pin_memory=True)

    writer = SummaryWriter(ckpt_dir)
    train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=args.batch_size,
          ckpt_dir=ckpt_dir,
          writer=writer)
    return None
示例#2
0
def backup_init(args):
    """Restore the network, optimizer and LR scheduler from ``args.model_file``.

    Args:
        args: namespace providing ``model_file``, ``device``, ``gpu_ids``,
            ``lr`` and ``decay_step``.

    Returns:
        Tuple ``(net, optimizer, last_epoch, scheduler)``. ``net`` is wrapped
        in ``nn.DataParallel`` when more than one GPU id is configured.
    """
    checkpoint = torch.load(args.model_file)  # load the checkpoint

    net = SeSface()
    net.load_state_dict(checkpoint['net'])  # restore learnable parameters
    net.to(args.device)
    if len(args.gpu_ids) > 1:
        # Rebind ``net`` itself; assigning the wrapper to another name would
        # discard it and return the unwrapped single-GPU module.
        net = nn.DataParallel(net)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999))
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer state

    last_epoch = checkpoint['epoch']  # epoch to resume from

    scheduler = StepLR(optimizer,
                       step_size=args.decay_step,
                       gamma=0.5,
                       last_epoch=last_epoch)
    scheduler.load_state_dict(checkpoint['scheduler'])
    return net, optimizer, last_epoch, scheduler
示例#3
0
def train(model, optimizer, loss_fn, epochs, train_loader, device, model_chckpt_path, checkpoint_save_interval,
          model_path, load_chckpt, log_interval):
    """Train ``model`` for ``epochs`` epochs with periodic checkpointing.

    The learning rate is multiplied by 0.1 once half of the epochs are done.

    Args:
        model: network to optimize; put into train mode here.
        optimizer: optimizer bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar loss.
        epochs: total number of epochs.
        train_loader: iterable yielding ``(images, targets)`` batches.
        device: device the batches are moved to.
        model_chckpt_path: file used to save/restore the training checkpoint.
        checkpoint_save_interval: a checkpoint is written whenever
            ``epoch % checkpoint_save_interval == 0``.
        model_path: file receiving the final ``model.state_dict()``.
        load_chckpt: when true and ``model_chckpt_path`` exists, resume from it.
        log_interval: print running statistics every this many batches.
    """
    epoch_start = 0

    # Clamp to 1: a decay period of 0 epochs is not usable once the
    # scheduler is actually stepped (happens for epochs < 2).
    scheduler = StepLR(optimizer, max(1, int(epochs * 0.5)), 0.1)

    if load_chckpt and os.path.isfile(model_chckpt_path):
        checkpoint = torch.load(model_chckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch_start = checkpoint['epoch']
        print("Training checkpoints found. Starting training from epoch %d." % epoch_start)

    model.train()
    for epoch in range(epoch_start, epochs):
        running_loss = 0.0
        processed_items = 0
        correct_predictions = 0
        for batch_num, (images, targets) in enumerate(train_loader):
            images, targets = images.to(device), targets.to(device)
            out = model(images)
            optimizer.zero_grad()
            loss = loss_fn(out, targets)
            loss.backward()
            optimizer.step()

            _, correct = calculate_correct_predictions(targets, out)
            running_loss += loss.item()
            processed_items += out.size()[0]
            correct_predictions += correct

            if (batch_num + 1) % log_interval == 0:
                print('[Epoch %d, Batch %4d] Loss: %.10f, Accuracy: %.5f' %
                      (epoch + 1, batch_num + 1, running_loss / processed_items, correct_predictions / processed_items))

        # Advance the schedule once per epoch; without this call the
        # scheduler is only saved/restored and the LR never decays.
        scheduler.step()

        if epoch % checkpoint_save_interval == 0:
            # Store the NEXT epoch to run so that resuming does not repeat
            # the epoch that has just been completed.
            torch.save({'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, model_chckpt_path)
    torch.save(model.state_dict(), model_path)
示例#4
0
def backup_init(args):
    """Rebuild the EDSR network, optimizer and scheduler from ``backup.pth``.

    The checkpoint is read from ``<args.backup_dir>/<args.name>/backup.pth``.

    Args:
        args: namespace providing ``backup_dir``, ``name``, ``device``,
            ``gpu_ids``, ``lr`` and ``decay_step``.

    Returns:
        Tuple ``(srnet, optimizer, last_epoch, scheduler)``; ``srnet`` is a
        ``nn.DataParallel`` wrapper when several GPU ids are configured.
    """
    ckpt_path = os.path.join(args.backup_dir, args.name, 'backup.pth')
    state = torch.load(ckpt_path)  # load the checkpoint

    # Set up the SR network and restore its learnable parameters.
    base_net = edsr.Edsr()
    base_net.load_state_dict(state['net'])
    base_net.to(args.device)
    use_data_parallel = len(args.gpu_ids) > 1
    srnet = nn.DataParallel(base_net) if use_data_parallel else base_net

    # Restore the optimizer state on top of a freshly built Adam instance.
    optimizer = optim.Adam(srnet.parameters(), lr=args.lr, betas=(0.9, 0.999))
    optimizer.load_state_dict(state['optimizer'])

    # Epoch at which training resumes.
    last_epoch = state['epoch']

    scheduler = StepLR(optimizer,
                       step_size=args.decay_step,
                       gamma=0.5,
                       last_epoch=last_epoch)
    scheduler.load_state_dict(state['scheduler'])
    return srnet, optimizer, last_epoch, scheduler
示例#5
0
    def on_validation_epoch_end(self, trainer, pl_module):
        """Lower the learning rate when validation accuracy has dropped.

        Called after every validation epoch. When
        ``pl_module.last_val_acc`` is below the value remembered from the
        previous call, the optimizer and its ``StepLR`` scheduler are rebuilt
        with the first parameter group's learning rate divided by
        ``self.decrease_factor`` and installed on ``trainer``. The current
        accuracy is remembered in every case.
        """
        accuracy_dropped = pl_module.last_val_acc < self.last_val_acc

        if accuracy_dropped:
            # Scale down the learning rate of the first parameter group only.
            optim_state = trainer.optimizers[0].state_dict()
            first_group = optim_state['param_groups'][0]
            first_group['lr'] = first_group['lr'] / self.decrease_factor

            param_groups = [
                {'params': pl_module.encoder.parameters()},
                {'params': pl_module.classifier.parameters()},
            ]
            rebuilt_optimizer = torch.optim.SGD(param_groups,
                                                lr=first_group['lr'])
            rebuilt_optimizer.load_state_dict(optim_state)

            # Recreate the scheduler on the new optimizer, carrying over the
            # previous scheduler's state (including its gamma).
            old_sched_state = trainer.lr_schedulers[0]['scheduler'].state_dict()
            rebuilt_scheduler = StepLR(optimizer=rebuilt_optimizer,
                                       step_size=1,
                                       gamma=old_sched_state['gamma'])
            rebuilt_scheduler.load_state_dict(old_sched_state)

            # Install the replacements on the trainer.
            trainer.optimizers = [rebuilt_optimizer]
            trainer.lr_schedulers = trainer.configure_schedulers([{
                'scheduler': rebuilt_scheduler,
                'name': 'learning_rate'
            }])

        # Remember the accuracy for the next comparison.
        self.last_val_acc = pl_module.last_val_acc
示例#6
0
def main():
    """Fine-tune a (pruned) model from ``args.refine`` and report its complexity.

    Steps: load data, rebuild the model from the ``refine`` checkpoint,
    optionally resume optimizer/scheduler state from ``args.resume``, run the
    train/test loop while checkpointing every epoch, and finally print FLOPs
    and parameter counts.

    Raises:
        AssertionError: if ``args.refine`` is not set.
        ValueError: if ``args.pruned`` is set with an unsupported ``args.arch``.
    """
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=device)

    if args.pruned:
        state_dict = checkpoint['state_dict_s']
        if args.arch == 'vgg':
            cfg = checkpoint['cfg']
            model = vgg_16_bn_sparse(cfg=cfg).to(device)
            # pruned = sum([1 for m in mask if mask == 0])
            # print(f"Pruned / Total: {pruned} / {len(mask)}")
        elif args.arch == 'resnet':
            mask = checkpoint['mask']
            model = resnet_56_sparse(has_mask=mask).to(device)
        elif args.arch == 'densenet':
            filters = checkpoint['filters']
            indexes = checkpoint['indexes']
            model = densenet_40_sparse(filters=filters,
                                       indexes=indexes).to(device)
        elif args.arch == 'googlenet':
            mask = checkpoint['mask']
            model = googlenet_sparse(has_mask=mask).to(device)
        else:
            # Fail with a clear message instead of an unbound ``model`` below.
            raise ValueError(
                f'Unsupported arch for a pruned model: {args.arch}')
        model.load_state_dict(state_dict)
    else:
        model = import_module('utils.preprocess').__dict__[f'{args.arch}'](
            args, checkpoint['state_dict_s'])

    # Disabled sanity check, kept for reference:
    # print_logger.info(f"Simply test after pruning...")
    # test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, 0)

    if args.test_only:
        return

    if args.keep_grad:
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    # Mask tensors are never handed to the optimizer.
    train_param = [
        param for name, param in model.named_parameters() if 'mask' not in name
    ]

    optimizer = optim.SGD(train_param,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=device)
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        # The loop below saves best_prec5 too; restore it when present.
        best_prec5 = checkpoint.get('best_prec5', best_prec5)
        # The loop below stores weights under 'state_dict_s'; accept the
        # 'state_dict' key as a fallback for checkpoints written elsewhere.
        if 'state_dict_s' in checkpoint:
            model.load_state_dict(checkpoint['state_dict_s'])
        else:
            model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        # NOTE(review): passing ``epoch`` to step() is deprecated in recent
        # PyTorch releases; kept to preserve the existing LR schedule.
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer,
              writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model,
                                      criterion, writer_test, epoch)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)

    print_logger.info(
        f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    # Model compression info
    flops, params = get_model_complexity_info(model.to(device), (3, 32, 32),
                                              as_strings=False,
                                              print_per_layer_stat=True)
    compressionInfo(flops, params)
示例#7
0
def main():
    """Train a GNN baseline on PCQM4M with DGL, with resume support.

    Parses the command line, seeds the RNGs, builds the data loaders and the
    requested model, then trains for ``args.epochs`` epochs. Whenever the
    validation MAE improves, a checkpoint is written to
    ``args.checkpoint_dir`` and a test submission to ``args.save_test_dir``
    (each only when the corresponding option is non-empty).

    Raises:
        ValueError: if ``--gnn`` names an unknown model type.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN to use, which can be from '
        '[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' breaks --help.
    parser.add_argument('--train_subset',
                        action='store_true',
                        help='use 10%% of the training set for training')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory. If not specified, '
                        'tensorboard will not be used.')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = SampleDglPCQM4MDataset(root='dataset/')

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()
    split_idx["train"] = split_idx["train"].type(torch.LongTensor)
    split_idx["test"] = split_idx["test"].type(torch.LongTensor)
    split_idx["valid"] = split_idx["valid"].type(torch.LongTensor)

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    train_idx = split_idx["train"]
    if args.train_subset:
        # Keep a random 10% of the training indices.
        subset_ratio = 0.1
        subset_idx = torch.randperm(
            len(train_idx))[:int(subset_ratio * len(train_idx))]
        train_idx = train_idx[subset_idx]
    train_loader = DataLoader(dataset[train_idx],
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)

    # Strings are compared by value (``!=``); identity checks against a
    # literal are implementation-dependent and raise a SyntaxWarning.
    if args.save_test_dir != '':
        test_loader = DataLoader(dataset[split_idx["test"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_dgl)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual-diffpool':
        model = DiffPoolGNN(gnn_type='gin', virtual_node=True,
                            **shared_params).to(device)
    elif args.gnn == 'gin-virtual-bayes-diffpool':
        model = BayesDiffPoolGNN(gnn_type='gin',
                                 virtual_node=True,
                                 **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        # The subset run uses a 10x longer schedule and epoch budget.
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    # --- load from latest checkpoint ---
    # start epoch (default = 1, unless resuming training)
    firstEpoch = 1
    # check if checkpoint exist -> load model
    checkpointFile = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if os.path.exists(checkpointFile):
        # map_location lets a checkpoint saved on one device be restored on
        # the device selected for this run.
        checkpointData = torch.load(checkpointFile, map_location=device)
        firstEpoch = checkpointData["epoch"]
        model.load_state_dict(checkpointData["model_state_dict"])
        optimizer.load_state_dict(checkpointData["optimizer_state_dict"])
        scheduler.load_state_dict(checkpointData["scheduler_state_dict"])
        best_valid_mae = checkpointData["best_val_mae"]
        num_params = checkpointData["num_params"]
        print(
            "Loaded existing weights from {}. Continuing from epoch: {} with best valid MAE: {}"
            .format(checkpointFile, firstEpoch, best_valid_mae))

    for epoch in range(firstEpoch, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer, args.gnn)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
示例#8
0
class Trainer():
    def __init__(self, cfg_data, pwd):
        """Set up data, network, optimizer, scheduler, bookkeeping and logging.

        Most settings are read from the module-level ``cfg`` object;
        ``cfg_data`` is only stored for later use. When ``cfg.RESUME`` is
        set, training state is restored from ``cfg.RESUME_PATH``.

        Args:
            cfg_data: dataset-specific configuration, kept as ``self.cfg_data``.
            pwd: working directory handed to the logger.

        Raises:
            ValueError: if ``cfg.OPT`` names an unsupported optimizer.
        """

        self.cfg_data = cfg_data
        self.train_loader, self.val_loader, self.restore_transform = datasets.loading_data(
            cfg.DATASET)

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET
        self.net = Crowd_locator(cfg.NET, cfg.GPU_ID, pretrained=True)

        if cfg.OPT == 'Adam':
            # Separate learning rates for the extractor and the binarization
            # module; weight decay is set on the extractor group only.
            self.optimizer = optim.Adam([{
                'params':
                self.net.Extractor.parameters(),
                'lr':
                cfg.LR_BASE_NET,
                'weight_decay':
                1e-5
            }, {
                'params': self.net.Binar.parameters(),
                'lr': cfg.LR_BM_NET
            }])
        else:
            # Fail here with a clear message rather than with an
            # AttributeError on ``self.optimizer`` a few lines below.
            raise ValueError('Unsupported optimizer: {}'.format(cfg.OPT))

        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg.NUM_EPOCH_LR_DECAY,
                                gamma=cfg.LR_DECAY)
        self.train_record = {
            'best_F1': 0,
            'best_Pre': 0,
            'best_Rec': 0,
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_nae': 1e20,
            'best_model_name': ''
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer()
        }

        self.epoch = 0
        self.i_tb = 0
        # ``np.int`` was removed from NumPy; ``len`` already returns an int.
        self.num_iters = cfg.MAX_EPOCH * len(self.train_loader)

        if cfg.RESUME:
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            # Continue with the epoch after the one stored in the checkpoint.
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.num_iters = latest_state['num_iters']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']
            print("Finish loading resume mode")
        self.writer, self.log_txt = logger(self.exp_path,
                                           self.exp_name,
                                           self.pwd,
                                           ['exp', 'figure', 'img', 'vis'],
                                           resume=cfg.RESUME)

    def forward(self):
        """Run the epoch loop from ``self.epoch`` up to ``cfg.MAX_EPOCH``.

        Every epoch is trained and timed. Validation runs only on epochs
        that are a multiple of ``cfg.VAL_FREQ`` and lie beyond
        ``cfg.VAL_DENSE_START``. ``self.scheduler`` is not stepped here.
        """
        for current_epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = current_epoch

            # --- training ---
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            train_seconds = self.timer['train time'].diff
            print('train time: {:.2f}s'.format(train_seconds))
            print('=' * 20)

            # --- validation (only on scheduled epochs) ---
            validation_due = (current_epoch % cfg.VAL_FREQ == 0
                              and current_epoch > cfg.VAL_DENSE_START)
            if not validation_due:
                continue

            self.timer['val time'].tic()
            self.validate()
            self.timer['val time'].toc(average=False)
            val_seconds = self.timer['val time'].diff
            print('val time: {:.2f}s'.format(val_seconds))

    def train(self):  # training for all datasets
        """Run one pass over ``self.train_loader``.

        For each batch: forward pass, backward pass on the sum of the two
        losses exposed by ``self.net.loss``, optimizer step, then a
        learning-rate update via ``adjust_learning_rate``. Scalars are
        logged every ``cfg.PRINT_FREQ`` iterations and a visualization is
        written every 100 iterations.
        """
        self.net.train()

        for step, batch in enumerate(self.train_loader):
            self.i_tb += 1
            self.timer['iter time'].tic()

            img, gt_map = batch
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            threshold_matrix, pre_map, binar_map = self.net(img, gt_map)
            head_map_loss, binar_map_loss = self.net.loss

            total_loss = head_map_loss + binar_map_loss
            total_loss.backward()
            self.optimizer.step()

            lr1, lr2 = adjust_learning_rate(self.optimizer, cfg.LR_BASE_NET,
                                            cfg.LR_BM_NET, self.num_iters,
                                            self.i_tb)

            if (step + 1) % cfg.PRINT_FREQ == 0:
                log_scalar = self.writer.add_scalar
                log_scalar('train_lr1', lr1, self.i_tb)
                log_scalar('train_lr2', lr2, self.i_tb)
                log_scalar('train_loss', head_map_loss.item(), self.i_tb)
                log_scalar('Binar_loss', binar_map_loss.item(), self.i_tb)

                # With several GPUs the binarization layer is reached
                # through a ``.module`` attribute.
                if len(cfg.GPU_ID) > 1:
                    binar_layer = self.net.Binar.module
                else:
                    binar_layer = self.net.Binar
                log_scalar('weight', binar_layer.weight.data.item(), self.i_tb)
                log_scalar('bias', binar_layer.bias.data.item(), self.i_tb)

                self.timer['iter time'].toc(average=False)
                param_groups = self.optimizer.param_groups
                report = (self.epoch + 1, step + 1, head_map_loss.item(),
                          param_groups[0]['lr'] * 10000,
                          param_groups[1]['lr'] * 10000,
                          self.timer['iter time'].diff)
                print('[ep %d][it %d][loss %.4f][lr1 %.4f][lr2 %.4f][%.2fs]' % report)
                print('       [t-max: %.3f t-min: %.3f]' %
                      (threshold_matrix.max().item(),
                       threshold_matrix.min().item()))

            if step % 100 == 0:
                # Only the boxes are used for visualization; the point
                # summary returned alongside them is ignored.
                _, boxes = self.get_boxInfo_from_Binar_map(
                    binar_map[0].detach().cpu().numpy())
                vis_results('tmp_vis', 0, self.writer, self.restore_transform,
                            img,
                            pre_map[0].detach().cpu().numpy(),
                            gt_map[0].detach().cpu().numpy(),
                            binar_map.detach().cpu().numpy(),
                            threshold_matrix.detach().cpu().numpy(),
                            boxes)

    def get_boxInfo_from_Binar_map(self, Binar_numpy, min_area=3):
        """Extract connected components from a binary map.

        Args:
            Binar_numpy: array that squeezes to 2-D; cast to ``uint8``.
            min_area: components whose area (column 4 of the stats table)
                is below this value are discarded.

        Returns:
            Tuple ``(pre_data, boxes)`` where ``pre_data`` is
            ``{'num': <count>, 'points': <centroids>}`` and ``boxes`` holds
            the matching rows of the stats table.
        """
        binary = Binar_numpy.squeeze().astype(np.uint8)
        assert binary.ndim == 2

        # centroids are given as (w, h)
        _, _, stats, centroids = cv2.connectedComponentsWithStats(
            binary, connectivity=4)

        # Row 0 is dropped before filtering (OpenCV reports the background
        # component first).
        component_stats = stats[1:]
        component_centers = centroids[1:]
        large_enough = component_stats[:, 4] >= min_area

        boxes = component_stats[large_enough]
        points = component_centers[large_enough]
        return {'num': len(points), 'points': points}, boxes

    def validate(self):
        self.net.eval()
        num_classes = 6
        losses = AverageMeter()
        cnt_errors = {
            'mae': AverageMeter(),
            'mse': AverageMeter(),
            'nae': AverageMeter()
        }
        metrics_s = {
            'tp': AverageMeter(),
            'fp': AverageMeter(),
            'fn': AverageMeter(),
            'tp_c': AverageCategoryMeter(num_classes),
            'fn_c': AverageCategoryMeter(num_classes)
        }
        metrics_l = {
            'tp': AverageMeter(),
            'fp': AverageMeter(),
            'fn': AverageMeter(),
            'tp_c': AverageCategoryMeter(num_classes),
            'fn_c': AverageCategoryMeter(num_classes)
        }

        c_maes = {
            'level': AverageCategoryMeter(5),
            'illum': AverageCategoryMeter(4)
        }
        c_mses = {
            'level': AverageCategoryMeter(5),
            'illum': AverageCategoryMeter(4)
        }
        c_naes = {
            'level': AverageCategoryMeter(5),
            'illum': AverageCategoryMeter(4)
        }
        gen_tqdm = tqdm(self.val_loader)
        for vi, data in enumerate(gen_tqdm, 0):
            img, dot_map, gt_data = data
            slice_h, slice_w = self.cfg_data.TRAIN_SIZE

            with torch.no_grad():
                img = Variable(img).cuda()
                dot_map = Variable(dot_map).cuda()
                # crop the img and gt_map with a max stride on x and y axis
                # size: HW: __C_NWPU.TRAIN_SIZE
                # stack them with a the batchsize: __C_NWPU.TRAIN_BATCH_SIZE
                crop_imgs, crop_gt, crop_masks = [], [], []
                b, c, h, w = img.shape

                if h * w < slice_h * 2 * slice_w * 2 and h % 16 == 0 and w % 16 == 0:
                    [pred_threshold, pred_map, __] = [
                        i.cpu()
                        for i in self.net(img, mask_gt=None, mode='val')
                    ]
                else:
                    if h % 16 != 0:
                        pad_dims = (0, 0, 0, 16 - h % 16)
                        h = (h // 16 + 1) * 16
                        img = F.pad(img, pad_dims, "constant")
                        dot_map = F.pad(dot_map, pad_dims, "constant")

                    if w % 16 != 0:
                        pad_dims = (0, 16 - w % 16, 0, 0)
                        w = (w // 16 + 1) * 16
                        img = F.pad(img, pad_dims, "constant")
                        dot_map = F.pad(dot_map, pad_dims, "constant")

                    assert img.size()[2:] == dot_map.size()[2:]

                    for i in range(0, h, slice_h):
                        h_start, h_end = max(min(h - slice_h, i),
                                             0), min(h, i + slice_h)
                        for j in range(0, w, slice_w):
                            w_start, w_end = max(min(w - slice_w, j),
                                                 0), min(w, j + slice_w)

                            crop_imgs.append(img[:, :, h_start:h_end,
                                                 w_start:w_end])
                            crop_gt.append(dot_map[:, :, h_start:h_end,
                                                   w_start:w_end])
                            mask = torch.zeros_like(dot_map).cpu()
                            mask[:, :, h_start:h_end, w_start:w_end].fill_(1.0)
                            crop_masks.append(mask)
                    crop_imgs, crop_gt, crop_masks = map(
                        lambda x: torch.cat(x, dim=0),
                        (crop_imgs, crop_gt, crop_masks))

                    # forward may need repeatng
                    crop_preds, crop_thresholds = [], []
                    nz, period = crop_imgs.size(
                        0), self.cfg_data.TRAIN_BATCH_SIZE
                    for i in range(0, nz, period):
                        [crop_threshold, crop_pred, __] = [
                            i.cpu()
                            for i in self.net(crop_imgs[i:min(nz, i + period)],
                                              mask_gt=None,
                                              mode='val')
                        ]
                        crop_preds.append(crop_pred)
                        crop_thresholds.append(crop_threshold)

                    crop_preds = torch.cat(crop_preds, dim=0)
                    crop_thresholds = torch.cat(crop_thresholds, dim=0)

                    # splice them to the original size
                    idx = 0
                    pred_map = torch.zeros_like(dot_map).cpu().float()
                    pred_threshold = torch.zeros_like(dot_map).cpu().float()
                    for i in range(0, h, slice_h):
                        h_start, h_end = max(min(h - slice_h, i),
                                             0), min(h, i + slice_h)
                        for j in range(0, w, slice_w):
                            w_start, w_end = max(min(w - slice_w, j),
                                                 0), min(w, j + slice_w)
                            pred_map[:, :, h_start:h_end,
                                     w_start:w_end] += crop_preds[idx]
                            pred_threshold[:, :, h_start:h_end,
                                           w_start:w_end] += crop_thresholds[
                                               idx]
                            idx += 1

                # for the overlapping area, compute average value
                    mask = crop_masks.sum(dim=0)
                    pred_map = (pred_map / mask)
                    pred_threshold = (pred_threshold / mask)

                # binar_map = self.net.Binar(pred_map.cuda(), pred_threshold.cuda()).cpu()
                a = torch.ones_like(pred_map)
                b = torch.zeros_like(pred_map)
                binar_map = torch.where(pred_map >= pred_threshold, a, b)

                dot_map = dot_map.cpu()
                loss = F.mse_loss(pred_map, dot_map)

                losses.update(loss.item())
                binar_map = binar_map.numpy()
                pred_data, boxes = self.get_boxInfo_from_Binar_map(binar_map)
                # print(pred_data, gt_data)

                tp_s, fp_s, fn_s, tp_c_s, fn_c_s, tp_l, fp_l, fn_l, tp_c_l, fn_c_l = eval_metrics(
                    num_classes, pred_data, gt_data)

                metrics_s['tp'].update(tp_s)
                metrics_s['fp'].update(fp_s)
                metrics_s['fn'].update(fn_s)
                metrics_s['tp_c'].update(tp_c_s)
                metrics_s['fn_c'].update(fn_c_s)
                metrics_l['tp'].update(tp_l)
                metrics_l['fp'].update(fp_l)
                metrics_l['fn'].update(fn_l)
                metrics_l['tp_c'].update(tp_c_l)
                metrics_l['fn_c'].update(fn_c_l)

                #    -----------Counting performance------------------
                gt_count, pred_cnt = gt_data['num'].numpy().astype(
                    float), pred_data['num']
                s_mae = abs(gt_count - pred_cnt)
                s_mse = ((gt_count - pred_cnt) * (gt_count - pred_cnt))
                cnt_errors['mae'].update(s_mae)
                cnt_errors['mse'].update(s_mse)
                if gt_count != 0:
                    s_nae = (abs(gt_count - pred_cnt) / gt_count)
                    cnt_errors['nae'].update(s_nae)

                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer,
                                self.restore_transform, img, pred_map.numpy(),
                                dot_map.numpy(), binar_map,
                                pred_threshold.numpy(), boxes)

        ap_s = metrics_s['tp'].sum / (metrics_s['tp'].sum +
                                      metrics_s['fp'].sum + 1e-20)
        ar_s = metrics_s['tp'].sum / (metrics_s['tp'].sum +
                                      metrics_s['fn'].sum + 1e-20)
        f1m_s = 2 * ap_s * ar_s / (ap_s + ar_s + 1e-20)
        ar_c_s = metrics_s['tp_c'].sum / (metrics_s['tp_c'].sum +
                                          metrics_s['fn_c'].sum + 1e-20)

        ap_l = metrics_l['tp'].sum / (metrics_l['tp'].sum +
                                      metrics_l['fp'].sum + 1e-20)
        ar_l = metrics_l['tp'].sum / (metrics_l['tp'].sum +
                                      metrics_l['fn'].sum + 1e-20)
        f1m_l = 2 * ap_l * ar_l / (ap_l + ar_l + 1e-20)
        ar_c_l = metrics_l['tp_c'].sum / (metrics_l['tp_c'].sum +
                                          metrics_l['fn_c'].sum + 1e-20)

        loss = losses.avg
        mae = cnt_errors['mae'].avg
        mse = np.sqrt(cnt_errors['mse'].avg)
        nae = cnt_errors['nae'].avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)

        self.writer.add_scalar('F1', f1m_l, self.epoch + 1)
        self.writer.add_scalar('Pre', ap_l, self.epoch + 1)
        self.writer.add_scalar('Rec', ar_l, self.epoch + 1)

        self.writer.add_scalar('overall_mae', mae, self.epoch + 1)
        self.writer.add_scalar('overall_mse', mse, self.epoch + 1)
        self.writer.add_scalar('overall_nae', nae, self.epoch + 1)

        self.train_record = update_model(
            self, [f1m_l, ap_l, ar_l, mae, mse, nae, loss])

        print_NWPU_summary(self, [f1m_l, ap_l, ar_l, mae, mse, nae, loss])
def main():
    """Train the sequence model on Spotify sessions, validating and checkpointing as it goes.

    Builds the train/validation ``SpotifyDataloader`` instances, creates a
    ``SeqModel`` with an Adam optimizer and a ``StepLR`` schedule (stepped
    once per epoch), optionally resumes from the most recently created
    ``check*.pth`` file under ``MODEL_SAVE_PATH``, and then runs the training
    loop.  Running statistics are appended to the module-level ``hist_*``
    lists every 500 sessions; validation plus a checkpoint happen every 8000
    sessions and again at the end of every epoch.
    """
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071),  # first 80% of sessions: training
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(config_fpath=args.config,
                                    mtrain_mode=True,  # True, because we use part of trainset as testset
                                    # Held-out slice used for testing; the full remaining 20%
                                    # would be (99965071, 124950714).
                                    data_sel=(99965071, 101065071),
                                    batch_size=TS_BATCH_SZ,
                                    shuffle=False,
                                    seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Resume from whichever checkpoint file was created most recently.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"), key=os.path.getctime)
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss_qlog = 0
        total_trloss_skip = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SM.train()
            # Builtin next() works on any iterator; the Python-2 style
            # ``.next()`` method is not part of the Python 3 iterator protocol.
            # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten()  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: bx70*20
            x = x.permute(0, 2, 1)

            # Prepare ground truth log and label, y
            y_qlog = x[:, :41, :].clone()  # bx41*20
            y_skip = labels.clone()  # bx20
            y_mask_qlog = y_mask.unsqueeze(1).repeat(1, 41, 1)  # bx41*20
            y_mask_skip = y_mask  # bx20

            # log shift: bx41*20
            log_shift = torch.zeros(batch_sz, 41, 20)
            log_shift[:, :, 1:] = x[:, :41, :-1]
            log_shift[:, :, 11:] = 0  # DELETE LOG QUE

            # labels_shift: bx1*20(model can only observe past labels)
            labels_shift = torch.zeros(batch_sz, 1, 20)
            labels_shift[:, 0, 1:] = labels[:, :-1].float()
            labels_shift[:, 0, 11:] = 0  # !!! NOLABEL for previous QUERY

            # support/query state labels: bx1*20
            sq_state = torch.zeros(batch_sz, 1, 20)
            sq_state[:, 0, :11] = 1

            # Pack x: bx72*20 (or bx32*20 if not using sup_logs)
            x = Variable(torch.cat((log_shift, x[:, 41:, :], labels_shift, sq_state), 1)).cuda(GPU)  # x: bx72*20

            # Forward & update
            y_hat_qlog, y_hat_skip = SM(x)  # y_hat: b*20

            # Calculate BCE loss (inputs and targets are both multiplied by the mask)
            loss_qlog = F.binary_cross_entropy_with_logits(input=y_hat_qlog.cuda(GPU) * y_mask_qlog.cuda(GPU),
                                                           target=y_qlog.cuda(GPU) * y_mask_qlog.cuda(GPU))
            loss_skip = F.binary_cross_entropy_with_logits(input=y_hat_skip.cuda(GPU) * y_mask_skip.cuda(GPU),
                                                           target=y_skip.cuda(GPU) * y_mask_skip.cuda(GPU))
            loss = loss_qlog + loss_skip
            total_trloss_qlog += loss_qlog.item()
            total_trloss_skip += loss_skip.item()
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(y_hat_skip.detach() * y_mask_skip.cuda(GPU)).cpu().numpy()  # bx20
            # ``np.int`` was removed in NumPy 1.24; it was only an alias of the
            # builtin ``int``, so the resulting dtype is unchanged.
            y_pred = (y_prob[:, 10:] >= 0.5).astype(int)  # bx10
            y_numpy = y_skip[:, 10:].numpy()  # bx10

            # Label Acc*
            total_corrects += np.sum((y_pred == y_numpy) * y_mask_skip[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, loss_qlog, loss_skip, y_hat_qlog, y_hat_skip

            if (session + 1) % 500 == 0:
                # Record the averages over the last 500 sessions, then reset the accumulators.
                hist_trloss_qlog.append(total_trloss_qlog / 500)
                hist_trloss_skip.append(total_trloss_skip / 500)
                hist_trloss.append(total_trloss / 500)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_sup = labels[0, (10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' +
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss(qlog|skip):{1:.6f}({2:.6f}|{3:.6f})  tr_acc:{4:.4f}".format(session,
                           hist_trloss[-1], hist_trloss_qlog[-1], hist_trloss_skip[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0
                total_trloss_qlog = 0
                total_trloss_skip = 0

            if (session + 1) % 8000 == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                torch.save({'ep': epoch, 'sess': session, 'SM_state': SM.state_dict(), 'loss': hist_trloss[-1],
                            'hist_trloss_qlog': hist_trloss_qlog, 'hist_trloss_skip': hist_trloss_skip, 'hist_vacc': hist_vacc,
                            'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                            'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save({'ep': epoch, 'sess': session, 'SM_state': SM.state_dict(), 'loss': hist_trloss[-1],
                    'hist_trloss_qlog': hist_trloss_qlog, 'hist_trloss_skip': hist_trloss_skip, 'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
示例#10
0
class Trainer():
    """Runs training and tiled validation for a ``CrowdCounter`` network.

    Configuration is read from the module-level ``cfg`` object; dataset
    specific values (``TRAIN_SIZE``, ``TRAIN_BATCH_SIZE``, ``LOG_PARA``) come
    from the ``cfg_data`` object handed to the constructor.
    """

    def __init__(self, cfg_data, pwd):
        """Build loaders, network, optimizer and scheduler; optionally resume.

        Args:
            cfg_data: dataset configuration object, stored as ``self.cfg_data``.
            pwd: stored as ``self.pwd`` and forwarded to ``logger``.
        """
        self.cfg_data = cfg_data
        self.train_loader, self.val_loader, self.restore_transform = datasets.loading_data(cfg.DATASET)

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET
        self.net = CrowdCounter(cfg.GPU_ID, self.net_name).cuda()
        # Only the parameters of the ``CCN`` sub-module are optimized.
        self.optimizer = optim.Adam(self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        self.scheduler = StepLR(self.optimizer, step_size=cfg.NUM_EPOCH_LR_DECAY, gamma=cfg.LR_DECAY)

        self.train_record = {'best_mae': 1e20, 'best_mse': 1e20, 'best_nae': 1e20, 'best_model_name': ''}
        self.timer = {label: Timer() for label in ('iter time', 'train time', 'val time')}

        self.epoch = 0
        self.i_tb = 0

        if cfg.PRE_GCC:
            pretrained_weights = torch.load(cfg.PRE_GCC_MODEL)
            self.net.load_state_dict(pretrained_weights)

        if cfg.RESUME:
            # Restore every piece of state stored in the resume checkpoint.
            snapshot = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(snapshot['net'])
            self.optimizer.load_state_dict(snapshot['optimizer'])
            self.scheduler.load_state_dict(snapshot['scheduler'])
            self.epoch = snapshot['epoch'] + 1
            self.i_tb = snapshot['i_tb']
            self.train_record = snapshot['train_record']
            self.exp_path = snapshot['exp_path']
            self.exp_name = snapshot['exp_name']

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)

    def forward(self):
        """Run the epoch loop: train, validate when scheduled, decay the LR."""
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch

            # training
            train_clock = self.timer['train time']
            train_clock.tic()
            self.train()
            train_clock.toc(average=False)

            print('train time: {:.2f}s'.format(train_clock.diff))
            print('=' * 20)

            # validation: every VAL_FREQ epochs, and every epoch after VAL_DENSE_START
            validation_due = epoch % cfg.VAL_FREQ == 0 or epoch > cfg.VAL_DENSE_START
            if validation_due:
                val_clock = self.timer['val time']
                val_clock.tic()
                self.validate()
                val_clock.toc(average=False)
                print('val time: {:.2f}s'.format(val_clock.diff))

            # The scheduler is only stepped once LR_DECAY_START has been passed.
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

    def train(self):
        """Run one pass over ``self.train_loader``, logging every PRINT_FREQ iterations."""
        self.net.train()
        iter_clock = self.timer['iter time']
        for step, data in enumerate(self.train_loader):
            iter_clock.tic()
            img, gt_map = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map, _ = self.net(img, gt_map)
            # The network exposes the loss of its latest forward pass as an attribute.
            loss = self.net.loss
            loss.backward()
            self.optimizer.step()

            if (step + 1) % cfg.PRINT_FREQ != 0:
                continue

            self.i_tb += 1
            self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
            iter_clock.toc(average=False)
            current_lr = self.optimizer.param_groups[0]['lr']
            print('[ep %d][it %d][loss %.4f][lr %.4f][%.2fs]' %
                  (self.epoch + 1, step + 1, loss.item(), current_lr * 10000, iter_clock.diff))
            gt_cnt = gt_map[0].sum().data / self.cfg_data.LOG_PARA
            pred_cnt = pred_map[0].sum().data / self.cfg_data.LOG_PARA
            print('        [cnt: gt: %.1f pred: %.2f]' % (gt_cnt, pred_cnt))

    def validate(self):
        """Evaluate on ``self.val_loader`` using tiled inference.

        Each image is cut into ``TRAIN_SIZE`` windows (windows at the border
        are shifted inward so they stay inside the image), the windows are fed
        through the network in groups of ``TRAIN_BATCH_SIZE``, and the outputs
        are summed back in place and divided by how many windows covered each
        pixel.  MAE / MSE / NAE are accumulated overall and per category, then
        written to the summary writer and passed to ``update_model`` and
        ``print_NWPU_summary``.
        """
        self.net.eval()

        def make_category_meters():
            return {'level': AverageCategoryMeter(5), 'illum': AverageCategoryMeter(4)}

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()
        naes = AverageMeter()

        c_maes = make_category_meters()
        c_mses = make_category_meters()
        c_naes = make_category_meters()

        for vi, data in enumerate(self.val_loader):
            img, dot_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                dot_map = Variable(dot_map).cuda()

                # crop the img and gt_map with a max stride on x and y axis
                # size: HW: __C_NWPU.TRAIN_SIZE
                # stack them with the batchsize: __C_NWPU.TRAIN_BATCH_SIZE
                _, _, h, w = img.shape
                rh, rw = self.cfg_data.TRAIN_SIZE
                # (top, bottom, left, right) of every window in row-major order;
                # the same list drives both the cropping and the splicing below.
                windows = [
                    (max(min(h - rh, i), 0), min(h, i + rh),
                     max(min(w - rw, j), 0), min(w, j + rw))
                    for i in range(0, h, rh)
                    for j in range(0, w, rw)
                ]

                img_tiles, dot_tiles, cover_tiles = [], [], []
                for top, bottom, left, right in windows:
                    img_tiles.append(img[:, :, top:bottom, left:right])
                    dot_tiles.append(dot_map[:, :, top:bottom, left:right])
                    cover = torch.zeros_like(dot_map).cuda()
                    cover[:, :, top:bottom, left:right].fill_(1.0)
                    cover_tiles.append(cover)
                crop_imgs = torch.cat(img_tiles, dim=0)
                crop_dots = torch.cat(dot_tiles, dim=0)
                crop_masks = torch.cat(cover_tiles, dim=0)

                # the forward pass may need several rounds
                pred_chunks, den_chunks = [], []
                nz, bz = crop_imgs.size(0), self.cfg_data.TRAIN_BATCH_SIZE
                for lo in range(0, nz, bz):
                    hi = min(nz, lo + bz)
                    pred_chunk, den_chunk = self.net.forward(crop_imgs[lo:hi], crop_dots[lo:hi])
                    pred_chunks.append(pred_chunk)
                    den_chunks.append(den_chunk)
                crop_preds = torch.cat(pred_chunks, dim=0)
                crop_dens = torch.cat(den_chunks, dim=0)

                # splice them to the original size
                pred_map = torch.zeros_like(dot_map).cuda()
                den_map = torch.zeros_like(dot_map).cuda()
                for idx, (top, bottom, left, right) in enumerate(windows):
                    pred_map[:, :, top:bottom, left:right] += crop_preds[idx]
                    den_map[:, :, top:bottom, left:right] += crop_dens[idx]

                # for the overlapping area, compute average value
                coverage = crop_masks.sum(dim=0).unsqueeze(0)
                pred_map = pred_map / coverage
                den_map = den_map / coverage

                pred_map = pred_map.data.cpu().numpy()
                dot_map = dot_map.data.cpu().numpy()
                den_map = den_map.data.cpu().numpy()

                pred_cnt = np.sum(pred_map) / self.cfg_data.LOG_PARA
                gt_count = np.sum(dot_map) / self.cfg_data.LOG_PARA

                count_gap = gt_count - pred_cnt
                s_mae = abs(count_gap)
                s_mse = count_gap * count_gap

                # self.net.loss holds the loss of the most recent forward call.
                losses.update(self.net.loss.item())
                maes.update(s_mae)
                mses.update(s_mse)

                attributes_pt = attributes_pt.squeeze()
                # Index 1 feeds the 'level' meters, index 0 the 'illum' meters.
                level_attr = attributes_pt[1]
                illum_attr = attributes_pt[0]

                c_maes['level'].update(s_mae, level_attr)
                c_mses['level'].update(s_mse, level_attr)
                c_maes['illum'].update(s_mae, illum_attr)
                c_mses['illum'].update(s_mse, illum_attr)

                # NAE is undefined when the ground-truth count is zero.
                if gt_count != 0:
                    s_nae = s_mae / gt_count
                    naes.update(s_nae)
                    c_naes['level'].update(s_nae, level_attr)
                    c_naes['illum'].update(s_nae, illum_attr)

                # Only the first validation sample is visualized.
                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, den_map)

        loss = losses.avg
        overall_mae = maes.avg
        overall_mse = np.sqrt(mses.avg)
        overall_nae = naes.avg

        step = self.epoch + 1
        self.writer.add_scalar('val_loss', loss, step)
        self.writer.add_scalar('overall_mae', overall_mae, step)
        self.writer.add_scalar('overall_mse', overall_mse, step)
        self.writer.add_scalar('overall_nae', overall_nae, step)

        scores = [overall_mae, overall_mse, overall_nae, loss]
        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name, scores, self.train_record, self.log_txt)

        print_NWPU_summary(self.exp_name, self.log_txt, self.epoch,
                           [overall_mae, overall_mse, overall_nae, loss],
                           self.train_record, c_maes, c_mses, c_naes)
def main():
    """Train and evaluate ``YOLO_VGG_16`` on the VOC dataset.

    Parses command-line options, builds the datasets, loaders, model, loss
    and SGD optimizer, optionally resumes from the checkpoint written for
    epoch ``start_epoch - 1``, and then alternates ``train`` and ``test`` for
    every epoch in ``[start_epoch, epochs)``.
    """
    # 1. argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--save_file_name', type=str, default='yolo_v2_vgg_16')
    parser.add_argument('--conf_thres', type=float, default=0.01)
    parser.add_argument('--save_path', type=str, default='./saves')
    parser.add_argument('--start_epoch', type=int, default=0)  # to resume

    opts = parser.parse_args()
    print(opts)

    # 2. device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3. visdom
    vis = visdom.Visdom()

    # 4. dataset
    # Raw string: "\D" and "\V" are not valid escape sequences, so the plain
    # literal triggers a warning on recent Python versions.  The runtime value
    # is unchanged.
    voc_root = r"D:\Data\VOC_ROOT"
    train_set = VOC_Dataset(root=voc_root, split='TRAIN')
    test_set = VOC_Dataset(root=voc_root, split='TEST')

    # 5. dataloader
    train_loader = DataLoader(dataset=train_set,
                              batch_size=opts.batch_size,
                              collate_fn=train_set.collate_fn,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=opts.num_workers)

    test_loader = DataLoader(dataset=test_set,
                             batch_size=1,
                             collate_fn=test_set.collate_fn,
                             shuffle=False)

    # 6. model
    model = YOLO_VGG_16().to(device)

    # 7. criterion
    criterion = Yolo_Loss(num_classes=20)

    # 8. optimizer
    optimizer = optim.SGD(params=model.parameters(),
                          lr=opts.lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    # 9. scheduler
    # NOTE(review): the scheduler is built and then immediately disabled by
    # rebinding the name to None, so no LR decay is applied.  The constructor
    # call is kept because it may touch the optimizer's param groups; confirm
    # before removing it.
    scheduler = StepLR(optimizer=optimizer, step_size=100, gamma=0.1)
    scheduler = None

    # 10. resume
    if opts.start_epoch != 0:

        # The checkpoint of the last finished epoch is "<name>.<epoch>.pth.tar".
        checkpoint = torch.load(
            os.path.join(opts.save_path, opts.save_file_name) +
            '.{}.pth.tar'.format(opts.start_epoch - 1))  # train
        model.load_state_dict(
            checkpoint['model_state_dict'])  # load model state dict
        optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])  # load optim state dict
        if scheduler is not None:
            scheduler.load_state_dict(
                checkpoint['scheduler_state_dict'])  # load sched state dict
        print('\nLoaded checkpoint from epoch %d.\n' %
              (int(opts.start_epoch) - 1))

    else:

        print('\nNo check point to resume.. train from scratch.\n')

    # 11. train
    for epoch in range(opts.start_epoch, opts.epochs):

        train(epoch=epoch,
              device=device,
              vis=vis,
              train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              scheduler=scheduler,
              save_path=opts.save_path,
              save_file_name=opts.save_file_name)

        if scheduler is not None:
            scheduler.step()

        # 12. test
        test(epoch=epoch,
             device=device,
             vis=vis,
             test_loader=test_loader,
             model=model,
             criterion=criterion,
             save_path=opts.save_path,
             save_file_name=opts.save_file_name,
             conf_thres=opts.conf_thres,
             eval=True)
示例#12
0
def main():
    """Train a sparse student ResNet-50 against a teacher with a discriminator.

    Loads the teacher weights from ``args.teacher_dir``, initializes the
    student from them, wraps teacher, student and discriminator in
    ``nn.DataParallel``, builds one optimizer/scheduler per parameter group
    (discriminator, student weights, student masks), optionally resumes from
    ``args.resume``, and then runs ``train``/``test`` for every epoch, saving
    a checkpoint after each one.
    """
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    traindir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                            'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                          'ILSVRC2012_img_val_rec')
    train_loader, val_loader = getTrainValDataset(traindir, valdir,
                                                  batch_sizes, 100, num_gpu,
                                                  num_workers)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')

    model_t = ResNet50()

    # Load teacher model; the checkpoint file is used directly as the state dict.
    state_dict_t = torch.load(args.teacher_dir,
                              map_location=torch.device(f"cuda:{args.gpus[0]}"))
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    # Freeze every teacher parameter except the last two.
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    # The student starts from the teacher weights; entries that exist only in
    # the student keep their initial values.
    model_s = ResNet50_sprase().to(args.gpus[0])
    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    model_d = Discriminator().to(args.gpus[0])

    model_s = nn.DataParallel(model_s).cuda()
    model_t = nn.DataParallel(model_t).cuda()
    model_d = nn.DataParallel(model_d).cuda()

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    # Student parameters are split by name: 'mask' parameters get their own
    # FISTA optimizer, everything else is trained with SGD.
    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr * 100, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_s = ckpt['state_dict_s']

        # model_s is always wrapped in DataParallel, so it expects keys with a
        # 'module.' prefix.  Checkpoints written with a single GPU already
        # carry that prefix (see the save logic below), while multi-GPU ones
        # do not; add it only when it is missing so both kinds load.
        new_state_dict_s = OrderedDict()
        for k, v in state_dict_s.items():
            key = k if k.startswith('module.') else 'module.' + k
            new_state_dict_s[key] = v

        best_prec1 = ckpt['best_prec1']
        # best_prec5 is stored in every checkpoint written below; restore it
        # so the final report does not restart from 0 after a resume.
        best_prec5 = ckpt.get('best_prec5', best_prec5)
        model_s.load_state_dict(new_state_dict_s)
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        start_epoch = ckpt['epoch']
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    models = [model_t, model_s, model_d]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        # NOTE(review): passing the epoch to step() is deprecated in recent
        # PyTorch releases; kept as-is to preserve the existing LR schedule.
        for s in schedulers:
            s.step(epoch)

        train(args, train_loader, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, val_loader, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        # With several GPUs the unwrapped module's weights are stored (no
        # 'module.' prefix); with one GPU the wrapper's state dict is stored.
        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        train_loader.reset()
        val_loader.reset()
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
示例#13
0
def main():
    """Train a student network against a teacher with a discriminator.

    Builds the teacher and the student from ``model.<args.arch>``, seeds the
    student with the teacher weights, and alternates ``train``/``test`` for
    ``args.num_epochs`` epochs, saving a checkpoint after every epoch.

    * ``args.resume``: restores student, discriminator, optimizer and
      scheduler state (plus the best precisions) before training.
    * ``args.test_only``: evaluates the student once and returns.

    After training, the best checkpoint is reloaded and handed to
    ``utils.preprocess.<args.arch>``.
    """
    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    arch_module = import_module(f'model.{args.arch}')
    model_t = arch_module.__dict__[args.teacher_model]().to(device)

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir, map_location=device)

    if args.arch == 'densenet':
        # Drop the first dotted component of every key and map the
        # classifier entries from ``linear.*`` to ``fc.*``.
        renamed = {'linear.weight': 'fc.weight', 'linear.bias': 'fc.bias'}
        state_dict_t = {}
        for k, v in ckpt_t['state_dict'].items():
            new_key = '.'.join(k.split('.')[1:])
            state_dict_t[renamed.get(new_key, new_key)] = v
    else:
        state_dict_t = ckpt_t['state_dict']

    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)

    # Freeze every teacher parameter tensor except the last two.
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = arch_module.__dict__[args.student_model]().to(device)

    # Start the student from the teacher weights; entries that exist only in
    # the student keep their initial values.
    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    # Checkpoints store the bare student weights (no ``module.`` prefix), so
    # both saving and resuming go through the unwrapped module.
    student_core = model_s.module if isinstance(model_s, nn.DataParallel) else model_s

    model_d = Discriminator().to(device)

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Mask parameters get their own (FISTA) optimizer; everything else in the
    # student is trained with SGD.
    param_s = [param for name, param in model_s.named_parameters() if 'mask' not in name]
    param_m = [param for name, param in model_s.named_parameters() if 'mask' in name]

    optimizer_s = optim.SGD(param_s, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume, map_location=device)
        best_prec1 = ckpt['best_prec1']
        # Older checkpoints may lack this entry; fall back to the initial value.
        best_prec5 = ckpt.get('best_prec5', best_prec5)
        start_epoch = ckpt['epoch']

        student_core.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]
    for epoch in range(start_epoch, args.num_epochs):
        # Schedulers are positioned explicitly at the current epoch before
        # the epoch's training starts.
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': student_core.state_dict(),
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            # Stored as the next epoch to run, which is what resume reads.
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print_logger.info(f"Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt', map_location=device)

    # Hand the best student weights to the architecture-specific routine in
    # ``utils.preprocess``; its return value is not used further here.
    model = import_module('utils.preprocess').__dict__[f'{args.arch}'](args, best_model['state_dict_s'])
示例#14
0
def train(data_dir, model_dir, checkpoint_path, pretrained_dvector_path,
          n_steps, save_every, decay_every, seg_len, ratio):
    """Train speaker verifier.

    Args:
        data_dir: Passed to ``SVDataset`` together with ``seg_len``.
        model_dir: Existing directory receiving checkpoints and the
            ``SummaryWriter`` event files.
        checkpoint_path: Checkpoint to resume from, or ``None`` to start
            from scratch.
        pretrained_dvector_path: Passed to ``SpeakerVerifier``; also recorded
            in saved checkpoints when not resuming.
        n_steps: Number of optimisation steps to run in this call.
        save_every: Checkpoint and validate every this many steps.
        decay_every: ``step_size`` of the ``StepLR`` scheduler (gamma 0.5).
        seg_len: Passed to ``SVDataset``.
        ratio: Passed to ``sample_index`` to choose the training indices.
    """

    # setup
    total_steps = 0
    assert os.path.isdir(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    dataset = SVDataset(data_dir, seg_len)
    train_index = sample_index(len(dataset), ratio)
    # A set makes the complement O(n) instead of a scan per index.
    train_lookup = set(train_index)
    valid_index = [x for x in range(len(dataset)) if x not in train_lookup]
    train_set = Subset(dataset, train_index)
    valid_set = Subset(dataset, valid_index)
    train_loader = DataLoader(train_set, batch_size=1024, shuffle=True,
                              collate_fn=pad_batch_with_label, drop_last=False)
    valid_loader = DataLoader(valid_set, batch_size=2, shuffle=False,
                              collate_fn=pad_batch_with_label, drop_last=False)
    train_loader_iter = iter(train_loader)
    print(f"Training starts with {train_set.dataset.total} speakers.")

    # load checkpoint
    ckpt = None
    # Recorded in every saved checkpoint; overridden by the resumed value.
    dvector_path = pretrained_dvector_path
    if checkpoint_path is not None:
        ckpt = torch.load(checkpoint_path, map_location=device)
        dvector_path = ckpt["dvector_path"]

    # build network and training tools
    # The model is moved to its device *before* the optimizer is created so
    # restored optimizer state lands on the same device as the parameters.
    model = SpeakerVerifier(pretrained_dvector_path, dataset.total).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)
    if ckpt is not None:
        total_steps = ckpt["total_steps"]
        model.load_state_dict(ckpt["state_dict"])
        # Key matches the one written by the save path below.
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    # start training
    for step in pbar:

        total_steps += 1

        # Restart the loader when the current pass over the data is exhausted.
        try:
            batch = next(train_loader_iter)
        except StopIteration:
            train_loader_iter = iter(train_loader)
            batch = next(train_loader_iter)

        data, label = batch
        logits = model(data.to(device))

        loss = criterion(logits, torch.LongTensor(label).to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_description(f"global = {total_steps}, loss = {loss:.4f}")
        writer.add_scalar("train_loss", loss, total_steps)

        if (step + 1) % save_every != 0:
            continue

        ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
        ckpt_dict = {
            "total_steps": total_steps,
            "dvector_path": dvector_path,
            "state_dict": model.state_dict(),
            "criterion": criterion.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        torch.save(ckpt_dict, ckpt_path)

        # Nothing to validate on when the split left no held-out samples.
        if len(valid_set) == 0:
            continue

        # NOTE(review): validation runs with the model left in training mode;
        # confirm whether switching to ``model.eval()`` is intended.
        val_acc = 0.0
        val_loss = 0.0
        for batch in valid_loader:
            data, label = batch
            with torch.no_grad():
                target = torch.LongTensor(label).to(device)
                logits = model(data.to(device))
                pred = logits.argmax(dim=1)
                val_acc += (pred == target).sum().item()
                val_loss += criterion(logits, target).item()
        val_acc /= len(valid_set)      # accuracy per sample
        val_loss /= len(valid_loader)  # mean of per-batch losses
        writer.add_scalar("valid_accuracy", val_acc, total_steps)
        writer.add_scalar("valid_loss", val_loss, total_steps)

    # Flush pending events and release the event file.
    writer.close()
    print("Training completed.")
示例#15
0
class DeployedESTransformer(object):
    def __init__(
            self,
            max_epochs=15,
            batch_size=1,
            batch_size_test=64,
            freq_of_test=-1,
            learning_rate=1e-3,
            lr_scheduler_step_size=9,
            lr_decay=0.9,
            per_series_lr_multip=1.0,
            gradient_eps=1e-8,
            transformer_weight_decay=0,
            noise_std=0.001,
            level_variability_penalty=80,
            testing_percentile=50,
            training_percentile=50,
            ensemble=False,
            seasonality=[4],
            input_size=4,
            output_size=8,
            frequency=None,
            max_periods=20,
            random_seed=1,
            device='cpu',
            root_dir='./',

            # Transformer parameters
            d_input=4,
            d_model=48,
            d_output=6,
            q=8,
            v=8,
            h=4,
            N=4,
            attention_size=None,
            dropout=0.3,
            chunk_mode='chunk',
            pe=None,
            pe_period=24,
            dataset_name=None):
        """Collect every hyper-parameter into a ``ModelConfig``.

        All arguments except ``dataset_name`` are forwarded to
        ``ModelConfig`` by keyword under the same name. ``device`` and
        ``dataset_name`` are additionally kept on the instance. The model
        itself is not built here; ``_fitted`` starts as ``False``.

        NOTE(review): ``seasonality`` defaults to a list literal, so the same
        list object is shared by every instance created with the default.
        """
        super().__init__()

        # Training, smoothing and data-layout settings.
        general_settings = dict(
            max_epochs=max_epochs,
            batch_size=batch_size,
            batch_size_test=batch_size_test,
            freq_of_test=freq_of_test,
            learning_rate=learning_rate,
            lr_scheduler_step_size=lr_scheduler_step_size,
            lr_decay=lr_decay,
            per_series_lr_multip=per_series_lr_multip,
            gradient_eps=gradient_eps,
            transformer_weight_decay=transformer_weight_decay,
            noise_std=noise_std,
            level_variability_penalty=level_variability_penalty,
            testing_percentile=testing_percentile,
            training_percentile=training_percentile,
            ensemble=ensemble,
            seasonality=seasonality,
            input_size=input_size,
            output_size=output_size,
            frequency=frequency,
            max_periods=max_periods,
            random_seed=random_seed,
            device=device,
            root_dir=root_dir,
        )

        # Transformer parameters
        transformer_settings = dict(
            d_input=d_input,
            d_model=d_model,
            d_output=d_output,
            q=q,
            v=v,
            h=h,
            N=N,
            attention_size=attention_size,
            dropout=dropout,
            chunk_mode=chunk_mode,
            pe=pe,
            pe_period=pe_period,
        )

        self.mc = ModelConfig(**general_settings, **transformer_settings)
        self.device = device
        self.dataset_name = dataset_name
        self._fitted = False

    def instantiate_estransformer(self, exogenous_size, n_series):
        """Record the data-dependent sizes on the config and build the model.

        Stores ``exogenous_size`` and ``n_series`` on ``self.mc``, then
        creates an ``ESTransformer`` from that config and moves it to
        ``self.mc.device``.
        """
        config = self.mc
        config.exogenous_size = exogenous_size
        config.n_series = n_series

        network = ESTransformer(config)
        self.estransformer = network.to(config.device)

    def fit(self,
            X_df,
            y_df,
            X_test_df=None,
            y_test_df=None,
            y_hat_benchmark='y_hat_naive2',
            warm_start=False,
            shuffle=True,
            verbose=True):
        """Prepare the data, build the model and train it.

        Args:
            X_df: Long-format DataFrame with columns ``unique_id``, ``ds``
                and ``x``.
            y_df: Long-format DataFrame with columns ``unique_id``, ``ds``
                and ``y``.
            X_test_df: Optional test features, stored for evaluation.
            y_test_df: Optional test targets; when given it must contain the
                ``y_hat_benchmark`` column.
            y_hat_benchmark: Name of the benchmark forecast column.
            warm_start: Forwarded to :meth:`train`.
            shuffle: Forwarded to :meth:`train`.
            verbose: Forwarded to :meth:`train`.

        Raises:
            AssertionError: If the inputs are malformed or the inferred
                frequencies of the supplied frames disagree.
        """
        # Transform long dfs to wide numpy
        assert isinstance(X_df, pd.DataFrame)
        assert isinstance(y_df, pd.DataFrame)
        assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
        assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
        if y_test_df is not None:
            assert y_hat_benchmark in y_test_df.columns, 'benchmark is not present in y_test_df, use y_hat_benchmark to define it'

        # Storing dfs for OWA evaluation, initializing min_owa
        self.y_train_df = y_df
        self.X_test_df = X_test_df
        self.y_test_df = y_test_df
        self.min_owa = 4.0
        self.min_epoch = 0

        # ``np.int`` no longer exists in current NumPy; ``np.integer`` covers
        # every NumPy integer scalar. Positional access avoids depending on
        # the frame having an index label ``0``.
        self.int_ds = isinstance(self.y_train_df['ds'].iloc[0],
                                 (int, np.integer))

        self.y_hat_benchmark = y_hat_benchmark

        X, y = self.long_to_wide(X_df, y_df)
        assert len(X) == len(y)
        assert X.shape[1] >= 3

        # Exogenous variables: one index per distinct category in column 1.
        unique_categories = np.unique(X[:, 1])
        self.mc.category_to_idx = dict(
            (word, index) for index, word in enumerate(unique_categories))
        exogenous_size = len(unique_categories)

        # Create batches (device in mc)
        self.train_dataloader = Iterator(mc=self.mc, X=X, y=y)

        # Random Seeds (model initialization)
        torch.manual_seed(self.mc.random_seed)
        np.random.seed(self.mc.random_seed)

        # Initialize model
        n_series = self.train_dataloader.n_series
        self.instantiate_estransformer(exogenous_size, n_series)

        # Validating frequencies (inferred from the first rows of each frame)
        X_train_frequency = pd.infer_freq(X_df.head()['ds'])
        y_train_frequency = pd.infer_freq(y_df.head()['ds'])
        self.frequencies = [X_train_frequency, y_train_frequency]

        if (X_test_df is not None) and (y_test_df is not None):
            X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
            y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
            self.frequencies += [X_test_frequency, y_test_frequency]

        assert len(set(self.frequencies)) <= 1, \
          "Match the frequencies of the dataframes {}".format(self.frequencies)

        self.mc.frequency = self.frequencies[0]
        print("Infered frequency: {}".format(self.mc.frequency))

        # Train model
        # Flag is set before training because the per-epoch evaluation inside
        # ``train`` calls methods that assert it.
        self._fitted = True
        self.train(dataloader=self.train_dataloader,
                   max_epochs=self.mc.max_epochs,
                   warm_start=warm_start,
                   shuffle=shuffle,
                   verbose=verbose)

    def train(self,
              dataloader,
              max_epochs,
              warm_start=False,
              shuffle=True,
              verbose=True):
        """Run the optimisation loop for ``max_epochs`` epochs.

        Args:
            dataloader: Batch source exposing ``n_batches``, ``get_batch()``
                and ``shuffle_dataset(random_seed=...)``.
            max_epochs: Number of epochs to run.
            warm_start: When ``True`` the optimizers and schedulers already
                stored on the instance are reused instead of being rebuilt.
            shuffle: Reshuffle the dataloader at the start of every epoch,
                seeded with the epoch number.
            verbose: Print per-epoch timing and losses, and run the
                evaluation passes.
        """
        if self.mc.ensemble:
            # Five slots initially referring to one copy; each epoch drops
            # the oldest slot and appends a fresh snapshot.
            self.estransformer_ensemble = [
                deepcopy(self.estransformer).to(self.mc.device)
            ] * 5
        if verbose:
            print(15 * '=' + ' Training ESTransformer  ' + 15 * '=' + '\n')

        # Model parameters
        es_parameters = filter(lambda p: p.requires_grad,
                               self.estransformer.es.parameters())
        params = sum([np.prod(p.size()) for p in es_parameters])
        print('Number of parameters of ES: ', params)

        trans_parameters = filter(lambda p: p.requires_grad,
                                  self.estransformer.transformer.parameters())
        params = sum([np.prod(p.size()) for p in trans_parameters])
        print('Number of parameters of Transformer: ', params)

        # Optimizers
        if not warm_start:
            self.es_optimizer = optim.Adam(
                params=self.estransformer.es.parameters(),
                lr=self.mc.learning_rate * self.mc.per_series_lr_multip,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps)

            self.es_scheduler = StepLR(
                optimizer=self.es_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=0.9)

            self.transformer_optimizer = optim.Adam(
                params=self.estransformer.transformer.parameters(),
                lr=self.mc.learning_rate,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps,
                weight_decay=self.mc.transformer_weight_decay)

            self.transformer_scheduler = StepLR(
                optimizer=self.transformer_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=self.mc.lr_decay)

        # Loss Functions
        train_tau = self.mc.training_percentile / 100
        train_loss = SmylLoss(
            tau=train_tau,
            level_variability_penalty=self.mc.level_variability_penalty)

        eval_tau = self.mc.testing_percentile / 100
        eval_loss = PinballLoss(tau=eval_tau)

        for epoch in range(max_epochs):
            self.estransformer.train()
            start = time.time()
            if shuffle:
                dataloader.shuffle_dataset(random_seed=epoch)
            losses = []
            for _ in range(dataloader.n_batches):
                self.es_optimizer.zero_grad()
                self.transformer_optimizer.zero_grad()

                batch = dataloader.get_batch()
                windows_y, windows_y_hat, levels = self.estransformer(batch)

                # Pinball loss on normalized values
                loss = train_loss(windows_y, windows_y_hat, levels)
                losses.append(loss.data.cpu().numpy())
                loss.backward()
                self.transformer_optimizer.step()
                self.es_optimizer.step()

            # Decay learning rate
            self.es_scheduler.step()
            self.transformer_scheduler.step()

            if self.mc.ensemble:
                copy_estransformer = deepcopy(self.estransformer)
                copy_estransformer.eval()
                self.estransformer_ensemble.pop(0)
                self.estransformer_ensemble.append(copy_estransformer)

            # Evaluation
            self.train_loss = np.mean(losses)
            if verbose:
                print("========= Epoch {} finished =========".format(epoch))
                print("Training time: {}".format(round(time.time() - start,
                                                       5)))
                print("Training loss ({} prc): {:.5f}".format(
                    self.mc.training_percentile, self.train_loss))
                self.test_loss = self.model_evaluation(dataloader, eval_loss)
                print("Testing loss  ({} prc): {:.5f}".format(
                    self.mc.testing_percentile, self.test_loss))
                # The forecast evaluation needs both test frames; they are
                # optional in ``fit``/``preprocess``, so skip it without them.
                if (self.X_test_df is not None) and (self.y_test_df is not None):
                    self.evaluate_model_prediction(self.y_train_df,
                                                   self.X_test_df,
                                                   self.y_test_df,
                                                   self.y_hat_benchmark,
                                                   epoch=epoch)
                # ``predict`` switches the model to eval mode; switch back.
                self.estransformer.train()

        if verbose:
            print('Train finished! \n')

    def predict(self, X_df, decomposition=False):
        assert type(X_df) == pd.core.frame.DataFrame
        assert 'unique_id' in X_df
        assert self._fitted, "Model not fitted yet"

        self.estransformer.eval()

        # Create fast dataloader
        if self.mc.n_series < self.mc.batch_size_test:
            new_batch_size = self.mc.n_series
        else:
            new_batch_size = self.mc.batch_size_test
        self.train_dataloader.update_batch_size(new_batch_size)
        dataloader = self.train_dataloader

        # Create Y_hat_panel placeholders
        output_size = self.mc.output_size
        n_unique_id = len(dataloader.sort_key['unique_id'])
        panel_unique_id = pd.Series(
            dataloader.sort_key['unique_id']).repeat(output_size)

        #access column with last train date
        panel_last_ds = pd.Series(dataloader.X[:, 2])
        panel_ds = []
        for i in range(len(panel_last_ds)):
            ranges = pd.date_range(start=panel_last_ds[i],
                                   periods=output_size + 1,
                                   freq=self.mc.frequency)
            panel_ds += list(ranges[1:])

        panel_y_hat = np.zeros((output_size * n_unique_id))

        # Predict
        count = 0
        for j in range(dataloader.n_batches):
            batch = dataloader.get_batch()
            batch_size = batch.y.shape[0]

            if self.mc.ensemble:
                y_hat = torch.zeros((5, batch_size, output_size))
                for i in range(5):
                    y_hat[i, :, :] = self.estransformer_ensemble[i].predict(
                        batch)
                y_hat = torch.mean(y_hat, 0)
            else:
                y_hat = self.estransformer.predict(batch)

            y_hat = y_hat.data.cpu().numpy()

            panel_y_hat[count:count +
                        output_size * batch_size] = y_hat.flatten()
            count += output_size * batch_size

        Y_hat_panel_dict = {
            'unique_id': panel_unique_id,
            'ds': panel_ds,
            'y_hat': panel_y_hat
        }

        assert len(panel_ds) == len(panel_y_hat) == len(panel_unique_id)

        Y_hat_panel = pd.DataFrame.from_dict(Y_hat_panel_dict)

        if 'ds' in X_df:
            Y_hat_panel = X_df.merge(Y_hat_panel,
                                     on=['unique_id', 'ds'],
                                     how='left')
        else:
            Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id'], how='left')

        self.train_dataloader.update_batch_size(self.mc.batch_size)
        return Y_hat_panel

    def per_series_evaluation(self, dataloader, criterion):
        with torch.no_grad():
            # Create fast dataloader
            if self.mc.n_series < self.mc.batch_size_test:
                new_batch_size = self.mc.n_series
            else:
                new_batch_size = self.mc.batch_size_test
            dataloader.update_batch_size(new_batch_size)

            per_series_losses = []
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.estransformer(batch)
                loss = criterion(windows_y, windows_y_hat)
                per_series_losses += loss.data.cpu().numpy().tolist()

            dataloader.update_batch_size(self.mc.batch_size)
        return per_series_losses

    def model_evaluation(self, dataloader, criterion):
        with torch.no_grad():
            # Create fast dataloader
            if self.mc.n_series < self.mc.batch_size_test:
                new_batch_size = self.mc.n_series
            else:
                new_batch_size = self.mc.batch_size_test
            dataloader.update_batch_size(new_batch_size)

            model_loss = 0.0
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.estransformer(batch)
                loss = criterion(windows_y, windows_y_hat)
                model_loss += loss.data.cpu().numpy()

            model_loss /= dataloader.n_batches
            dataloader.update_batch_size(self.mc.batch_size)
        return model_loss

    def evaluate_model_prediction(self,
                                  y_train_df,
                                  X_test_df,
                                  y_test_df,
                                  y_hat_benchmark='y_hat_naive2',
                                  epoch=None):
        """Score the model's forecasts against a benchmark via ``owa``.

        Args:
            y_train_df: Training targets (``unique_id``, ``ds``, ``y``), used
                as the in-sample panel.
            X_test_df: Frame passed to :meth:`predict`.
            y_test_df: Test targets, also holding the benchmark column.
            y_hat_benchmark: Name of the benchmark forecast column.
            epoch: When given and the OWA improves on the best seen so far,
                stored as ``self.min_epoch``.

        Returns:
            The ``(owa, mase, smape)`` triple returned by ``owa``.
        """
        assert self._fitted, "Model not fitted yet"

        key_cols = ['unique_id', 'ds']
        y_panel = y_test_df.filter(key_cols + ['y'])
        # Present the benchmark forecasts under the common ``y_hat`` name.
        y_benchmark_panel = y_test_df.filter(
            key_cols + [y_hat_benchmark]).rename(
                columns={y_hat_benchmark: 'y_hat'})
        y_hat_panel = self.predict(X_test_df)
        y_insample = y_train_df.filter(key_cols + ['y'])

        scores = owa(y_panel,
                     y_hat_panel,
                     y_benchmark_panel,
                     y_insample,
                     seasonality=self.mc.naive_seasonality)
        model_owa, model_mase, model_smape = scores

        # Track the best (lowest) OWA and the epoch that produced it.
        if model_owa < self.min_owa:
            self.min_owa = model_owa
            if epoch is not None:
                self.min_epoch = epoch

        print('OWA: {} '.format(np.round(model_owa, 3)))
        print('SMAPE: {} '.format(np.round(model_smape, 3)))
        print('MASE: {} '.format(np.round(model_mase, 3)))

        return model_owa, model_mase, model_smape

    def long_to_wide(self, X_df, y_df):
        """Pivot the long-format frames into per-series arrays.

        ``y_df['y']`` is attached to a copy of ``X_df`` (aligned on the frame
        index), distinct ``ds`` values are ranked in sorted order, and the
        targets are pivoted to one row per ``unique_id`` and one column per
        rank.

        Returns:
            ``(X, y)``: ``X`` has the columns ``unique_id``, the first ``x``
            of each series and its last ``ds``; ``y`` holds the pivoted
            targets.
        """
        long_df = X_df.copy()
        long_df['y'] = y_df['y'].copy()

        # Rank every distinct timestamp by its position in sorted order.
        ordered_ds = np.sort(long_df['ds'].unique())
        ds_map = {stamp: rank for rank, stamp in enumerate(ordered_ds)}
        long_df['ds_map'] = long_df['ds'].map(ds_map)
        long_df = long_df.sort_values(by=['ds_map', 'unique_id'])

        wide = long_df.pivot(index='unique_id', columns='ds_map')['y']

        # Per-series metadata, taken after the sort above.
        first_x = long_df[['unique_id', 'x']].groupby('unique_id').first()
        final_ds = long_df[['unique_id', 'ds']].groupby('unique_id').last()
        assert len(first_x) == len(long_df.unique_id.unique())
        wide['x'] = first_x
        wide['last_ds'] = final_ds
        wide = wide.reset_index().rename_axis(None, axis=1)

        rank_columns = long_df.ds_map.unique().tolist()
        X = wide.filter(items=['unique_id', 'x', 'last_ds']).values
        y = wide.filter(items=rank_columns).values

        return X, y

    def get_dir_name(self, root_dir=None):
        if not root_dir:
            assert self.mc.root_dir
            root_dir = self.mc.root_dir

        data_dir = self.mc.dataset_name
        model_parent_dir = os.path.join(root_dir, data_dir)
        model_path = ['estransformer_{}'.format(str(self.mc.copy))]
        model_dir = os.path.join(model_parent_dir, '_'.join(model_path))
        return model_dir

    def save(self, model_dir=None, copy=None, epoch=None):
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        print('Saving model to:\n {}'.format(model_dir) + '\n')
        torch.save(
            {
                'model_state_dict': self.estransformer.state_dict(),
                'es_optimizer': self.es_optimizer.state_dict(),
                'es_scheduler': self.es_scheduler.state_dict(),
                'transformer_optimizer':
                self.transformer_optimizer.state_dict(),
                'transformer_scheduler':
                self.transformer_scheduler.state_dict(),
                'epoch': epoch
            },
            model_dir + 'model_epoch_' + str(epoch) + '_' + self.dataset_name)

    def load(self, model_dir=None, copy=None, conti_train=False):
        # Run preprocess to instantialize estransformer and its optimizer
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        temp_model = torch.load(model_dir,
                                map_location=torch.device(self.device))

        # Load model
        self.estransformer.load_state_dict(temp_model['model_state_dict'])

        if conti_train:
            # Instantiate optimizer and scheduler
            self.es_optimizer = optim.Adam(
                params=self.estransformer.es.parameters(),
                lr=self.mc.learning_rate * self.mc.per_series_lr_multip,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps)

            self.es_scheduler = StepLR(
                optimizer=self.es_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=0.9)

            self.transformer_optimizer = optim.Adam(
                params=self.estransformer.transformer.parameters(),
                lr=self.mc.learning_rate,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps,
                weight_decay=self.mc.transformer_weight_decay)

            self.transformer_scheduler = StepLR(
                optimizer=self.transformer_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=self.mc.lr_decay)

            # Load state
            self.es_optimizer.load_state_dict(temp_model['es_optimizer'])
            self.es_scheduler.load_state_dict(temp_model['es_scheduler'])
            self.transformer_optimizer.load_state_dict(
                temp_model['transformer_optimizer'])
            self.transformer_scheduler.load_state_dict(
                temp_model['transformer_scheduler'])
            self.min_epoch = temp_model['epoch']

            self.train(dataloader=self.train_dataloader,
                       max_epochs=self.mc.max_epochs,
                       warm_start=True,
                       shuffle=True,
                       verbose=True)

    def preprocess(self,
                   X_df,
                   y_df,
                   X_test_df=None,
                   y_test_df=None,
                   y_hat_benchmark='y_hat_naive2',
                   warm_start=False,
                   shuffle=True,
                   verbose=True):
        """Prepare the data and build the model without training it.

        Performs the same validation, reshaping, dataloader creation, model
        instantiation and frequency check as :meth:`fit`, then marks the
        instance as fitted so a checkpoint can be loaded into it.

        Args:
            X_df: Long-format DataFrame with columns ``unique_id``, ``ds``
                and ``x``.
            y_df: Long-format DataFrame with columns ``unique_id``, ``ds``
                and ``y``.
            X_test_df: Optional test features, stored for evaluation.
            y_test_df: Optional test targets; when given it must contain the
                ``y_hat_benchmark`` column.
            y_hat_benchmark: Name of the benchmark forecast column.
            warm_start: Accepted for signature parity with :meth:`fit`;
                not used here.
            shuffle: Accepted for signature parity; not used here.
            verbose: Accepted for signature parity; not used here.

        Raises:
            AssertionError: If the inputs are malformed or the inferred
                frequencies of the supplied frames disagree.
        """
        # Transform long dfs to wide numpy
        assert isinstance(X_df, pd.DataFrame)
        assert isinstance(y_df, pd.DataFrame)
        assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
        assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
        if y_test_df is not None:
            assert y_hat_benchmark in y_test_df.columns, 'benchmark is not present in y_test_df, use y_hat_benchmark to define it'

        # Storing dfs for OWA evaluation, initializing min_owa
        self.y_train_df = y_df
        self.X_test_df = X_test_df
        self.y_test_df = y_test_df
        self.min_owa = 4.0
        self.min_epoch = 0

        # ``np.int`` no longer exists in current NumPy; ``np.integer`` covers
        # every NumPy integer scalar. Positional access avoids depending on
        # the frame having an index label ``0``.
        self.int_ds = isinstance(self.y_train_df['ds'].iloc[0],
                                 (int, np.integer))

        self.y_hat_benchmark = y_hat_benchmark

        X, y = self.long_to_wide(X_df, y_df)
        assert len(X) == len(y)
        assert X.shape[1] >= 3

        # Exogenous variables: one index per distinct category in column 1.
        unique_categories = np.unique(X[:, 1])
        self.mc.category_to_idx = dict(
            (word, index) for index, word in enumerate(unique_categories))
        exogenous_size = len(unique_categories)

        # Create batches (device in mc)
        self.train_dataloader = Iterator(mc=self.mc, X=X, y=y)

        # Random Seeds (model initialization)
        torch.manual_seed(self.mc.random_seed)
        np.random.seed(self.mc.random_seed)

        # Initialize model
        n_series = self.train_dataloader.n_series

        self.instantiate_estransformer(exogenous_size, n_series)

        # Validating frequencies (inferred from the first rows of each frame)
        X_train_frequency = pd.infer_freq(X_df.head()['ds'])
        y_train_frequency = pd.infer_freq(y_df.head()['ds'])
        self.frequencies = [X_train_frequency, y_train_frequency]

        if (X_test_df is not None) and (y_test_df is not None):
            X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
            y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
            self.frequencies += [X_test_frequency, y_test_frequency]

        assert len(set(self.frequencies)) <= 1, \
          "Match the frequencies of the dataframes {}".format(self.frequencies)

        self.mc.frequency = self.frequencies[0]
        print("Infered frequency: {}".format(self.mc.frequency))

        # No training here: only mark the instance as ready for use.
        self._fitted = True
示例#16
0
                                 lr=modelparams['lr_psi'],
                                 amsgrad=True)
# Separate Adam optimizer (AMSGrad enabled) for the design variable ``d``.
optimizer_design = torch.optim.Adam(
    [d],
    lr=modelparams['lr_d'],
    amsgrad=True,
)

# Step-wise learning-rate decay, configured independently for each optimizer.
scheduler_psi = StepLR(
    optimizer_psi,
    step_size=modelparams['step_psi'],
    gamma=modelparams['gamma_psi'],
)
scheduler_design = StepLR(
    optimizer_design,
    step_size=modelparams['step_d'],
    gamma=modelparams['gamma_d'],
)

if RELOAD:
    # Restore optimizer and scheduler state from the loaded run metadata,
    # then drop the metadata since it is not needed past this point.
    optimizer_psi.load_state_dict(meta_info['optimizer_psi_state'])
    optimizer_design.load_state_dict(meta_info['optimizer_design_state'])
    scheduler_psi.load_state_dict(meta_info['scheduler_psi_state'])
    scheduler_design.load_state_dict(meta_info['scheduler_design_state'])

    del meta_info

# --- TRAINING --- #

# Count the scalar entries of every parameter that requires gradients.
num_params = sum(
    [np.prod(p.size()) for p in model.parameters() if p.requires_grad])
print("Number of trainable parameters", num_params)

# initialize dataset
dataset = SIRDatasetPE_obs(d, data['prior_samples'], device)
示例#17
0
class Trainer():
    """Adversarial trainer for a crowd-counting network with two discriminators.

    ``self.net`` is optimised with its own loss on source batches plus an
    adversarial loss computed by the discriminator selected by ``self.dis``
    on target batches; that discriminator is then updated on the detached
    predictions of both domains.
    """

    def __init__(self, dataloader, cfg_data, pwd, cfg):
        """Create the networks, optimisers, schedulers, loaders and logger.

        Args:
            dataloader: Zero-argument callable returning ``(source_loader,
                target_loader, test_loader, restore_transform)``.
            cfg_data: Dataset configuration; ``LOG_PARA`` is read from it
                whenever counts are computed.
            pwd: Forwarded unchanged to ``logger``.
            cfg: Experiment configuration object read throughout the class.
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd
        self.cfg = cfg

        self.net_name = cfg.NET

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name, DA=True).cuda()

        self.num_parameters = sum(
            param.nelement() for param in self.net.parameters())
        print('num_parameters:', self.num_parameters)
        # Only the parameters of the CCN sub-module are optimised.
        self.optimizer = optim.Adam(self.net.CCN.parameters(),
                                    lr=cfg.LR,
                                    weight_decay=1e-4)
        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg.NUM_EPOCH_LR_DECAY,
                                gamma=cfg.LR_DECAY)

        self.train_record = {
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_model_name': '_'
        }

        self.hparam = {
            'lr': cfg.LR,
            'n_epochs': cfg.MAX_EPOCH,
            'number of parameters': self.num_parameters,
            'dataset': cfg.DATASET
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer()
        }

        self.epoch = 0
        self.i_tb = 0

        # --- discriminators ---
        # NOTE(review): if cfg.GAN is neither 'Vanilla' nor 'LS', or cfg.NET
        # is not 'Res50', the attributes below are never set and building
        # the discriminators fails with AttributeError.
        if cfg.GAN == 'Vanilla':
            self.bce_loss = torch.nn.BCELoss()
        elif cfg.GAN == 'LS':
            self.bce_loss = torch.nn.MSELoss()

        if cfg.NET == 'Res50':
            self.channel1, self.channel2 = 1024, 128

        self.D = [
            FCDiscriminator(self.channel1, self.bce_loss).cuda(),
            FCDiscriminator(self.channel2, self.bce_loss).cuda()
        ]
        self.D[0].apply(weights_init())
        self.D[1].apply(weights_init())

        # Index of the discriminator used for the adversarial loss.
        self.dis = self.cfg.DIS

        self.d_opt = [
            optim.Adam(self.D[0].parameters(),
                       lr=self.cfg.D_LR,
                       betas=(0.9, 0.99)),
            optim.Adam(self.D[1].parameters(),
                       lr=self.cfg.D_LR,
                       betas=(0.9, 0.99))
        ]

        self.scheduler_D = [
            StepLR(self.d_opt[0],
                   step_size=cfg.NUM_EPOCH_LR_DECAY,
                   gamma=cfg.LR_DECAY),
            StepLR(self.d_opt[1],
                   step_size=cfg.NUM_EPOCH_LR_DECAY,
                   gamma=cfg.LR_DECAY)
        ]

        # --- loss weights ---
        self.lambda_adv = [cfg.LAMBDA_ADV1, cfg.LAMBDA_ADV2]

        if cfg.PRE_GCC:
            print('===================Loaded Pretrained GCC================')
            weight = torch.load(cfg.PRE_GCC_MODEL)['net']
            # Try the key-converted checkpoint first and fall back to the raw
            # one.  `except Exception` (not a bare `except`) so Ctrl-C and
            # SystemExit are not swallowed by the fallback.
            try:
                self.net.load_state_dict(convert_state_dict_gcc(weight))
            except Exception:
                self.net.load_state_dict(weight)

        # --- data loaders ---
        self.source_loader, self.target_loader, self.test_loader, self.restore_transform = dataloader(
        )
        self.source_len = len(self.source_loader.dataset)
        self.target_len = len(self.target_loader.dataset)
        print("source:", self.source_len)
        print("target:", self.target_len)
        self.source_loader_iter = cycle(self.source_loader)
        self.target_loader_iter = cycle(self.target_loader)

        if cfg.RESUME:
            print('===================Loaded model to resume================')
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']
        self.writer, self.log_txt = logger(self.exp_path,
                                           self.exp_name,
                                           self.pwd,
                                           'exp',
                                           self.source_loader,
                                           self.test_loader,
                                           resume=cfg.RESUME,
                                           cfg=cfg)

    def forward(self):
        """Run the epoch loop from ``self.epoch`` up to ``cfg.MAX_EPOCH``.

        Each epoch trains once, steps the three LR schedulers when the epoch
        is past ``cfg.LR_DECAY_START``, and validates with the routine that
        matches ``self.data_mode`` when the epoch is a multiple of
        ``cfg.VAL_FREQ`` or past ``cfg.VAL_DENSE_START``.
        """
        print('forward!!')
        with open(self.log_txt, 'a') as f:
            f.write(str(self.net) + '\n')
            f.write('num_parameters:' + str(self.num_parameters) + '\n')

        for epoch in range(self.epoch, self.cfg.MAX_EPOCH):
            self.epoch = epoch

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            if epoch > self.cfg.LR_DECAY_START:
                self.scheduler.step()
                self.scheduler_D[0].step()
                self.scheduler_D[1].step()

            print('train time: {:.2f}s'.format(self.timer['train time'].diff))
            print('=' * 20)
            self.net.eval()

            # validation
            if epoch % self.cfg.VAL_FREQ == 0 or epoch > self.cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                # Compare by value: `is` on a string literal tests object
                # identity and silently skips validation for equal strings
                # that are different objects.
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50', 'Mall']:
                    self.validate_V1()
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                elif self.data_mode == 'NTU':
                    self.validate_V4()
                self.timer['val time'].toc(average=False)
                print('val time: {:.2f}s'.format(self.timer['val time'].diff))

    def train(self):  # training for all datasets
        """Run one epoch of joint generator / discriminator updates.

        The number of iterations is the length of the longer of the source
        and target loaders; batches are drawn from the two cycling iterators.
        """
        self.net.train()

        for i in range(max(len(self.source_loader), len(self.target_loader))):
            torch.cuda.empty_cache()
            self.timer['iter time'].tic()
            img, gt_img = next(self.source_loader_iter)
            tar, gt_tar = next(self.target_loader_iter)

            img = Variable(img).cuda()
            gt_img = Variable(gt_img).cuda()

            tar = Variable(tar).cuda()
            gt_tar = Variable(gt_tar).cuda()

            # --- generator loss ---
            self.optimizer.zero_grad()

            # Freeze both discriminators while the generator is back-propagated.
            for param in self.D[0].parameters():
                param.requires_grad = False
            for param in self.D[1].parameters():
                param.requires_grad = False

            # source
            pred = self.net(img, gt_img)
            loss = self.net.loss
            if not self.cfg.LOSS_TOG:
                loss.backward()

            # target
            pred_tar = self.net(tar, gt_tar)

            loss_adv = self.D[self.dis].cal_loss(pred_tar[self.dis],
                                                 0) * self.lambda_adv[self.dis]

            # LOSS_TOG selects one combined backward pass instead of one per loss.
            if not self.cfg.LOSS_TOG:
                loss_adv.backward()
            else:
                loss += loss_adv
                loss.backward()

            # --- discriminator loss ---
            loss_d = self.dis_update(pred, pred_tar)
            self.d_opt[0].step()
            self.d_opt[1].step()

            self.optimizer.step()

            if (i + 1) % self.cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.writer.add_scalar('loss_adv', loss_adv.item(), self.i_tb)
                self.writer.add_scalar('loss_d', loss_d.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)

                print('[ep %d][it %d][loss %.4f][loss_adv %.8f][loss_d %.4f][lr %.8f][%.2fs]' % \
                      (self.epoch + 1, i + 1, loss.item(), loss_adv.item() if loss_adv else 0, loss_d.item(), self.optimizer.param_groups[0]['lr'],
                       self.timer['iter time'].diff))
                print('        [cnt: gt: %.1f pred: %.2f]' %
                      (gt_img[0].sum().data / self.cfg_data.LOG_PARA,
                       pred[-1][0].sum().data / self.cfg_data.LOG_PARA))

                print('        [tar: gt: %.1f pred: %.2f]' %
                      (gt_tar[0].sum().data / self.cfg_data.LOG_PARA,
                       pred_tar[-1][0].sum().data / self.cfg_data.LOG_PARA))

        self.writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'],
                               self.epoch + 1)

    def gen_update(self, img, tar, gt_img, gt_tar):
        """Placeholder for a separate generator update; does nothing.

        Returns:
            None. The generator update is currently inlined in ``train``.
        """
        pass

    def dis_update(self, pred, pred_tar):
        """Back-propagate the loss of the selected discriminator.

        Args:
            pred: Network outputs for the source batch; items 0 and 1 are
                detached before use.
            pred_tar: Network outputs for the target batch, handled the same.

        Returns:
            The summed discriminator loss (source labelled 0, target
            labelled 1).  The optimiser step is left to the caller.
        """
        self.d_opt[self.dis].zero_grad()

        # Re-enable gradients that `train` switched off for the generator pass.
        for param in self.D[0].parameters():
            param.requires_grad = True
        for param in self.D[1].parameters():
            param.requires_grad = True

        # source
        pred = [pred[0].detach(), pred[1].detach()]

        loss_d = self.D[self.dis].cal_loss(pred[self.dis], 0)
        if not self.cfg.LOSS_TOG:
            loss_d.backward()

        loss_D = loss_d

        # target
        pred_tar = [pred_tar[0].detach(), pred_tar[1].detach()]

        loss_d = self.D[self.dis].cal_loss(pred_tar[self.dis], 1)
        if not self.cfg.LOSS_TOG:
            loss_d.backward()

        loss_D += loss_d

        if self.cfg.LOSS_TOG:
            loss_D.backward()

        return loss_D

    def _evaluate(self, loader, visualize=False):
        """Run the network over ``loader`` and return ``[mae, mse, loss]``.

        ``mse`` is the square root of the mean squared per-image count error.
        When ``visualize`` is true the first batch is passed to
        ``vis_results``.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, (img, gt_map) in enumerate(loader):
            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                _, _, pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA
                    diff = gt_count - pred_cnt

                    # The batch loss is recorded once per image in the batch.
                    losses.update(self.net.loss.item())
                    maes.update(abs(diff))
                    mses.update(diff * diff)

                if visualize and vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer,
                                self.restore_transform, img, pred_map, gt_map)

        return [maes.avg, np.sqrt(mses.avg), losses.avg]

    def validate_train(self):
        """Evaluate on the source loader and print the summary."""
        metrics = self._evaluate(self.source_loader)

        print("test on source domain")
        print_NTU_summary(self.log_txt, self.epoch, metrics,
                          self.train_record)

    def validate_V4(self):  # validate_V4 for NTU
        """Evaluate on the test loader, log scalars and update the record."""
        metrics = self._evaluate(self.test_loader, visualize=True)
        mae, mse, loss = metrics

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer,
                                         self.scheduler, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name,
                                         metrics, self.train_record,
                                         False, self.log_txt)

        print_NTU_summary(self.log_txt, self.epoch, metrics,
                          self.train_record)
def main():
    """Train ``SeqModel`` on session batches, validating and checkpointing.

    Every epoch iterates over all training session batches, logs the running
    loss/accuracy every 500 sessions, validates and saves a checkpoint every
    20000 sessions and once more at the end of the epoch, then steps the LR
    scheduler.  When ``args.load_continue_latest`` is set, the most recently
    created ``check*.pth`` under ``MODEL_SAVE_PATH`` is restored first.
    """
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% used for training
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 104965071),  # (99965071, 124950714) is the full 20% test split
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.7)

    CF_model = MLP_Regressor().cuda(GPU)
    CF_checkpoint = torch.load(CF_CHECKPOINT_PATH,
                               map_location='cuda:{}'.format(GPU))
    CF_model.load_state_dict(CF_checkpoint['model_state'])

    def save_checkpoint(epoch, session):
        """Save model, optimiser, scheduler and history lists to disk."""
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS
            # Builtin next() works on Python 2 and 3 iterators alike.
            x, labels, y_mask, num_items, _ = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            # If num_items was odd number, query has one more item.
            num_support = num_items[:, 0].detach().numpy().flatten()
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x[:, 10:, :41] = 0  # DELETE METALOG QUE

            # labels_shift: (model can only observe past labels)
            labels_shift = torch.zeros(batch_sz, 20, 1)
            labels_shift[:, 1:, 0] = labels[:, :-1].float()
            # !!! NOLABEL for previous QUERY
            labels_shift[:, 11:, 0] = 0
            # support/query state labels
            sq_state = torch.zeros(batch_sz, 20, 1)
            sq_state[:, :11, 0] = 1
            # compute lastfm_output
            x_audio = x[:, :, 41:].data.clone()
            x_audio = Variable(x_audio, requires_grad=False).cuda(GPU)
            x_emb_lastfm, x_lastfm = CF_model(x_audio)
            x_lastfm = x_lastfm.cpu()
            del x_emb_lastfm

            # Pack x: bx122*20
            x = Variable(
                torch.cat((x_lastfm, x, labels_shift, sq_state),
                          dim=2).permute(0, 2, 1)).cuda(GPU)

            # Forward & update
            y_hat = SM(x)  # y_hat: b*20
            # Move the mask to the GPU once; it is reused for loss and decision.
            y_mask_gpu = y_mask.cuda(GPU)
            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask_gpu,
                target=labels.cuda(GPU) * y_mask_gpu)
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask_gpu).detach().cpu().numpy()  # bx20
            # `np.int` was only an alias of the builtin and was removed in
            # NumPy 1.24; `int` yields the same dtype.
            y_pred = (y_prob[:, 10:] >= 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            y_query_mask = y_mask[:, 10:].numpy()
            total_corrects += np.sum((y_pred == y_numpy) * y_query_mask)
            total_query += np.sum(num_query)
            # Restore GPU memory
            del loss, y_hat, y_mask_gpu

            if (session + 1) % 500 == 0:
                # NOTE(review): the loss is accumulated over 500 sessions but
                # divided by 900 - confirm the intended normaliser.
                hist_trloss.append(total_trloss / 900)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_sup = labels[
                    0, :num_support[0]].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % 20000 == 0:
                # Validation
                validate(mval_loader, SM, CF_model, eval_mode=True)
                # Save
                save_checkpoint(epoch, session)
        # Validation
        validate(mval_loader, SM, CF_model, eval_mode=True)
        # Save
        save_checkpoint(epoch, session)
        SM_scheduler.step()
class Trainer():
    """Trainer for a crowd counter with a regression and a count-class loss.

    For the ``'CMTL'`` network the ground-truth counts are binned into
    ``self.num_classes`` classes; bin width and class weights are estimated
    from the first batches of the training loader.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Create loaders, loss functions, network, optimiser and logger.

        Args:
            dataloader: Zero-argument callable returning ``(train_loader,
                val_loader, restore_transform)``.
            cfg_data: Dataset configuration; ``LOG_PARA`` is read from it.
            pwd: Forwarded unchanged to ``logger``.
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET

        self.train_loader, self.val_loader, self.restore_transform = dataloader()

        # NOTE(review): loss_1_fn / loss_2_fn are only bound for 'CMTL'; any
        # other net name fails at the CrowdCounter call below.
        if self.net_name in ['CMTL']:

            # use for gt's class labeling
            self.max_gt_count = 0.
            self.min_gt_count = 0x7f7f7f
            self.num_classes = 10
            self.bin_val = 0.

            self.pre_max_min_bin_val()
            ce_weights = torch.from_numpy(self.pre_weights()).float()

            loss_1_fn = nn.MSELoss()

            loss_2_fn = nn.BCELoss(weight=ce_weights)

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name, loss_1_fn, loss_2_fn).cuda()
        self.optimizer = optim.Adam(self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        self.scheduler = StepLR(self.optimizer, step_size=cfg.NUM_EPOCH_LR_DECAY, gamma=cfg.LR_DECAY)

        self.train_record = {'best_mae': 1e20, 'best_mse': 1e20, 'best_model_name': ''}
        self.timer = {'iter time': Timer(), 'train time': Timer(), 'val time': Timer()}

        self.i_tb = 0
        self.epoch = 0

        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        if cfg.RESUME:
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)

    def pre_max_min_bin_val(self):
        """Estimate max/min ground-truth count and the class bin width.

        Scans the first 50 batches of ``self.train_loader`` and stores the
        results in ``self.max_gt_count``, ``self.min_gt_count`` and
        ``self.bin_val``.
        """
        for i, data in enumerate(self.train_loader, 0):
            if i >= 50:
                # Later batches were never used; stop instead of draining
                # the whole loader.
                break
            # for getting the max and min people count
            _, gt_map = data

            for j in range(0, gt_map.size()[0]):
                temp_count = gt_map[j].sum() / self.cfg_data.LOG_PARA
                # Two independent checks: with `elif`, a sample that raised
                # the maximum could never lower the minimum.
                if temp_count > self.max_gt_count:
                    self.max_gt_count = temp_count
                if temp_count < self.min_gt_count:
                    self.min_gt_count = temp_count

        print('[max_gt: %.2f min_gt: %.2f]' % (self.max_gt_count, self.min_gt_count))
        self.bin_val = (self.max_gt_count - self.min_gt_count) / float(self.num_classes)

    def pre_weights(self):
        """Return per-class weights from the first 100 training batches.

        Builds a histogram of count classes, turns it into
        ``1 - frequency`` and normalises the result to sum to one.

        Returns:
            numpy array of length ``self.num_classes``.
        """
        count_class_hist = np.zeros(self.num_classes)
        for i, data in enumerate(self.train_loader, 0):
            if i >= 100:
                break
            _, gt_map = data
            for j in range(0, gt_map.size()[0]):
                temp_count = gt_map[j].sum() / self.cfg_data.LOG_PARA
                class_idx = min(int(temp_count / self.bin_val), self.num_classes - 1)
                count_class_hist[class_idx] += 1

        wts = count_class_hist
        wts = 1 - wts / (sum(wts))
        wts = wts / sum(wts)
        print('pre_wts:')
        print(wts)

        return wts

    def online_assign_gt_class_labels(self, gt_map_batch):
        """Return one-hot count-class labels for a batch of ground-truth maps.

        Args:
            gt_map_batch: Tensor whose first dimension indexes the batch.

        Returns:
            Float tensor of shape ``(batch, self.num_classes)``; the class is
            ``int(count / self.bin_val)`` clamped to the last class.
        """
        batch = gt_map_batch.size()[0]
        # Builtin `int` replaces the `np.int` alias removed in NumPy 1.24.
        label = np.zeros((batch, self.num_classes), dtype=int)

        for i in range(0, batch):
            gt_count = (gt_map_batch[i].sum().item() / self.cfg_data.LOG_PARA)

            # generate gt's label same as implement of CMTL by Viswa
            class_idx = min(int(gt_count / self.bin_val), self.num_classes - 1)
            label[i, class_idx] = 1

        return torch.from_numpy(label).float()

    def forward(self):
        """Run the epoch loop from ``self.epoch`` up to ``cfg.MAX_EPOCH``."""
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            print('train time: {:.2f}s'.format(self.timer['train time'].diff))
            print('=' * 20)

            # validation
            if epoch % cfg.VAL_FREQ == 0 or epoch > cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                # Compare strings by value, not identity.
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                    self.validate_V1()
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print('val time: {:.2f}s'.format(self.timer['val time'].diff))

    def train(self):  # training for all datasets
        """Run one epoch over ``self.train_loader``, optimising loss1 + loss2."""
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            # train net
            self.timer['iter time'].tic()
            img, gt_map = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            gt_label = self.online_assign_gt_class_labels(gt_map)
            gt_label = Variable(gt_label).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map, gt_label)
            loss1, loss2 = self.net.loss
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.writer.add_scalar('train_loss1', loss1.item(), self.i_tb)
                self.writer.add_scalar('train_loss2', loss2.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)
                # The learning rate is printed scaled by 10000.
                print('[ep %d][it %d][loss %.8f, %.8f, %.8f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), loss1.item(), loss2.item(), self.optimizer.param_groups[0]['lr'] * 10000,
                         self.timer['iter time'].diff))
                print('        [cnt: gt: %.1f pred: %.2f]' % (gt_map[0].sum().data / self.cfg_data.LOG_PARA, pred_map[0].sum().data / self.cfg_data.LOG_PARA))

    def validate_V1(self):  # validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate with one count per batch and update the best-model record."""
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                gt_label = self.online_assign_gt_class_labels(gt_map)
                gt_label = Variable(gt_label).cuda()

                pred_map = self.net.forward(img, gt_map, gt_label)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                # Counts are summed over the whole batch here.
                pred_cnt = np.sum(pred_map) / self.cfg_data.LOG_PARA
                gt_count = np.sum(gt_map) / self.cfg_data.LOG_PARA

                # Only the first loss component is tracked as validation loss.
                loss1, loss2 = self.net.loss
                loss = loss1.item()
                losses.update(loss)
                maes.update(abs(gt_count - pred_cnt))
                mses.update((gt_count - pred_cnt) * (gt_count - pred_cnt))
                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb, self.exp_path, self.exp_name, \
            [mae, mse, loss], self.train_record, self.log_txt)

        print_summary(self.exp_name, [mae, mse, loss], self.train_record)

    def validate_V2(self):  # validate_V2 for WE
        """Validate over a list of sub-loaders, averaging MAE over five categories."""
        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        for i_sub, i_loader in enumerate(self.val_loader, 0):

            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    # NOTE(review): unlike train/validate_V1 no gt_label is
                    # passed and net.loss is used as a single tensor - confirm
                    # this path works with the two-loss network.
                    pred_map = self.net.forward(img, gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    for i_img in range(pred_map.shape[0]):

                        pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                        gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                        losses.update(self.net.loss.item(), i_sub)
                        maes.update(abs(gt_count - pred_cnt), i_sub)

                    if vi == 0:
                        vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb, self.exp_path, self.exp_name, \
            [mae, 0, loss], self.train_record, self.log_txt)
        print_summary(self.exp_name, [mae, 0, loss], self.train_record)

    def validate_V3(self):  # validate_V3 for GCC
        """Validate per image, also tracking errors by level, time and weather."""
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {'level': AverageCategoryMeter(9), 'time': AverageCategoryMeter(8), 'weather': AverageCategoryMeter(7)}
        c_mses = {'level': AverageCategoryMeter(9), 'time': AverageCategoryMeter(8), 'weather': AverageCategoryMeter(7)}

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data
            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                # NOTE(review): same call/loss shape caveat as validate_V2.
                pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):

                    # Per-image counts: summing the whole batch here would
                    # assign the batch total to every image.
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count - pred_cnt)
                    s_mse = (gt_count - pred_cnt) * (gt_count - pred_cnt)

                    losses.update(self.net.loss.item())
                    maes.update(s_mae)
                    mses.update(s_mse)
                    c_maes['level'].update(s_mae, attributes_pt[i_img][0])
                    c_mses['level'].update(s_mse, attributes_pt[i_img][0])
                    # Floor division keeps the category index integral under
                    # Python 3 / true-division semantics.
                    c_maes['time'].update(s_mae, attributes_pt[i_img][1] // 3)
                    c_mses['time'].update(s_mse, attributes_pt[i_img][1] // 3)
                    c_maes['weather'].update(s_mae, attributes_pt[i_img][2])
                    c_mses['weather'].update(s_mse, attributes_pt[i_img][2])

                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb, self.exp_path, self.exp_name, \
            [mae, mse, loss], self.train_record, self.log_txt)

        c_mses['level'] = np.sqrt(c_mses['level'].avg)
        c_mses['time'] = np.sqrt(c_mses['time'].avg)
        c_mses['weather'] = np.sqrt(c_mses['weather'].avg)
        print_GCC_summary(self.exp_name, [mae, mse, loss], self.train_record, c_maes, c_mses)
示例#20
0
class Trainer():
    """Training/validation driver for a crowd-counting network.

    The network, Adam optimizer and StepLR scheduler are configured from the
    global ``cfg``; pretrained weights or a full checkpoint may be restored
    before the epoch loop in :meth:`forward` starts.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Create the model, optimizer, scheduler, data loaders and logger.

        Args:
            dataloader: Callable returning ``(train_loader, val_loader,
                restore_transform)``.
            cfg_data: Dataset configuration object; ``LOG_PARA`` is read
                from it when counts are computed.
            pwd: Forwarded unchanged to ``logger``.
        """
        self.cfg_data = cfg_data
        self.pwd = pwd
        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.net_name = cfg.NET

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name).cuda()
        # Only the CCN sub-module's parameters are optimized.
        self.optimizer = optim.Adam(
            self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        # self.optimizer = optim.SGD(self.net.parameters(), cfg.LR, momentum=0.95,weight_decay=5e-4)
        self.scheduler = StepLR(
            self.optimizer,
            step_size=cfg.NUM_EPOCH_LR_DECAY,
            gamma=cfg.LR_DECAY)

        # Best-so-far bookkeeping; handed to and replaced by update_model().
        self.train_record = {
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_model_name': '',
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer(),
        }
        self.epoch = 0
        self.i_tb = 0

        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        loaders = dataloader()
        self.train_loader, self.val_loader, self.restore_transform = loaders

        if cfg.RESUME:
            # Restore everything stored in the checkpoint and continue with
            # the epoch after the saved one.
            state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(state['net'])
            self.optimizer.load_state_dict(state['optimizer'])
            self.scheduler.load_state_dict(state['scheduler'])
            self.epoch = state['epoch'] + 1
            self.i_tb = state['i_tb']
            self.train_record = state['train_record']
            self.exp_path = state['exp_path']
            self.exp_name = state['exp_name']

        self.writer, self.log_txt = logger(
            self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)

    def forward(self):
        """Run epochs from ``self.epoch`` up to ``cfg.MAX_EPOCH``.

        Every epoch trains once; validation runs when the epoch is a multiple
        of ``cfg.VAL_FREQ`` or lies beyond ``cfg.VAL_DENSE_START``.
        """
        train_clock = self.timer['train time']
        val_clock = self.timer['val time']

        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            train_clock.tic()
            self.train()
            train_clock.toc(average=False)
            print('train time: {:.2f}s'.format(train_clock.diff))
            print('=' * 20)

            # validation
            validation_due = (epoch % cfg.VAL_FREQ == 0
                              or epoch > cfg.VAL_DENSE_START)
            if not validation_due:
                continue

            val_clock.tic()
            if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                self.validate_V1()
            val_clock.toc(average=False)
            print('val time: {:.2f}s'.format(val_clock.diff))

    def train(self):
        """Train for one pass over ``self.train_loader`` (all datasets).

        Every ``cfg.PRINT_FREQ`` iterations the loss is written to the
        summary writer, printed, and appended to ``self.log_txt`` when that
        path is set.
        """
        self.net.train()
        iter_clock = self.timer['iter time']

        for step, (img, gt_map) in enumerate(self.train_loader):
            iter_clock.tic()
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map)
            loss = self.net.loss  # the network exposes its loss as an attribute
            loss.backward()
            self.optimizer.step()

            if (step + 1) % cfg.PRINT_FREQ != 0:
                continue

            self.i_tb += 1
            loss_value = loss.item()
            self.writer.add_scalar('train_loss', loss_value, self.i_tb)
            iter_clock.toc(average=False)

            # Each line is formatted once and reused for console and file.
            status_line = '[ep %d][it %d][loss %.5f][lr %.8f][%.3fs]' % (
                self.epoch + 1,
                step + 1,
                loss_value,
                self.optimizer.param_groups[0]['lr'] * 1,
                iter_clock.diff)
            print(status_line)

            # Counts of the first sample in the batch only.
            count_line = '        [cnt: gt: %.1f pred: %.2f]' % (
                gt_map[0].sum().data / self.cfg_data.LOG_PARA,
                pred_map[0].sum().data / self.cfg_data.LOG_PARA)
            print(count_line)

            # write txt file
            if self.log_txt is not None:
                with open(self.log_txt, 'a') as f:
                    f.write(status_line + ' \n')
                    f.write(count_line + ' \n')

    def validate_V1(self):
        """Validate on SHHA, SHHB, QNRF or UCF50.

        Accumulates per-image absolute and squared count errors plus the
        batch loss, logs them, lets ``update_model`` refresh the training
        record, and appends a summary to ``self.log_txt`` when it is set.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, (img, gt_map) in enumerate(self.val_loader):
            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)
                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                # The loss belongs to the whole batch; it is recorded once
                # per image so it carries the same weight as the errors.
                batch_loss = self.net.loss.item()

                for i_img, pred_single in enumerate(pred_map):
                    pred_cnt = np.sum(pred_single) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA
                    error = gt_count - pred_cnt

                    losses.update(batch_loss)
                    maes.update(abs(error))
                    mses.update(error * error)

                # if vi == 0:
                #     vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)  # reported "mse" is the root of the mean squared error
        loss = losses.avg

        for tag, value in (('val_loss', loss), ('mae', mae), ('mse', mse)):
            self.writer.add_scalar(tag, value, self.epoch + 1)

        self.train_record = update_model(
            self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
            self.exp_path, self.exp_name, [mae, mse, loss],
            self.train_record, self.log_txt)

        print_summary(self.exp_name, [mae, mse, loss], self.train_record)

        # write txt file
        if self.log_txt is not None:
            rule = '=' * 50 + '\n'
            sub_rule = '    ' + '-' * 20 + '\n'
            report = [
                rule,
                self.exp_name + '\n',
                sub_rule,
                '    [mae %.2f mse %.2f], [val loss %.4f] \n' % (mae, mse, loss),
                sub_rule,
                '[best] [model: %s] , [mae %.2f], [mse %.2f] \n' % (
                    self.train_record['best_model_name'],
                    self.train_record['best_mae'],
                    self.train_record['best_mse']),
                rule,
                '\n\n',
            ]
            with open(self.log_txt, 'a') as f:
                f.writelines(report)
示例#21
0
def main():
    """Train the feature encoder and relation network on Spotify sessions.

    Builds the training and validation loaders, the two networks with their
    Adam optimizers and StepLR schedulers, optionally resumes from the most
    recently created ``check*.pth`` file under ``MODEL_SAVE_PATH``, and then
    runs the epoch/session loop with periodic logging, validation and
    checkpointing.

    NOTE(review): ``hist_trloss``, ``hist_tracc``, ``hist_vacc`` and
    ``hist_vloss`` are used but not defined in this function — presumably
    module-level lists; verify.
    """
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% for training
        batch_size=TR_BATCH_SZ,
        shuffle=True)  # shuffle must be set to True later...

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 124950714),  # 20% for testing
        batch_size=2048,
        shuffle=False)

    # Init neural net
    #FeatEnc = MLP(input_sz=29, hidden_sz=512, output_sz=64).apply(weights_init).cuda(GPU)
    FeatEnc = MLP(input_sz=29, hidden_sz=256, output_sz=64).cuda(GPU)
    RN = RelationNetwork().cuda(GPU)

    # One optimizer/scheduler pair per network.
    FeatEnc_optim = torch.optim.Adam(FeatEnc.parameters(), lr=LEARNING_RATE)
    RN_optim = torch.optim.Adam(RN.parameters(), lr=LEARNING_RATE)

    # NOTE(review): neither scheduler is stepped anywhere in this function,
    # so the learning rate stays constant unless it is stepped elsewhere —
    # confirm whether that is intended.
    FeatEnc_scheduler = StepLR(FeatEnc_optim, step_size=100000, gamma=0.2)
    RN_scheduler = StepLR(RN_optim, step_size=100000, gamma=0.2)

    if args.load_continue_latest is None:
        START_EPOCH = 0

    else:
        # Resume from the checkpoint file with the latest creation time.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['hist_trloss'][-1]))
        FeatEnc.load_state_dict(checkpoint['FE_state'])
        RN.load_state_dict(checkpoint['RN_state'])
        FeatEnc_optim.load_state_dict(checkpoint['FE_opt_state'])
        RN_optim.load_state_dict(checkpoint['RN_opt_state'])
        FeatEnc_scheduler.load_state_dict(checkpoint['FE_sch_state'])
        RN_scheduler.load_state_dict(checkpoint['RN_sch_state'])
        # The saved epoch is re-run from its first session; the stored
        # 'sess' index and history lists are not restored here.
        START_EPOCH = checkpoint['ep']

    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):

        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        # Running statistics, reset after every 900-session report below.
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):

            FeatEnc.train()
            RN.train()
            # NOTE(review): `.next()` is the Python 2 iterator method; this
            # relies on the loader's iterator exposing it — verify.
            x_sup, x_que, x_log_sup, x_log_que, label_sup, label_que, num_items, index = tr_sessions_iter.next(
            )  # Fixed 13 Dec: logs are separated; the query should not include logs
            x_sup, x_que = Variable(x_sup).cuda(GPU), Variable(x_que).cuda(GPU)
            x_log_sup, x_log_que = Variable(x_log_sup).cuda(GPU), Variable(
                x_log_que).cuda(GPU)
            label_sup = Variable(label_sup).cuda(GPU)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten(
            )  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            x_sup = x_sup.unsqueeze(2)  # 1x7*29 --> 1x7x1*29
            x_que = x_que.unsqueeze(2)  # 1x8*29 --> 1x8x1*29

            # - feature encoder
            x_feat_sup = FeatEnc(x_sup)  # 1x7x1*64
            x_feat_que = FeatEnc(x_que)  # 1x8x1*64

            # - relation network
            y_hat = RN(x_feat_sup, x_feat_que, x_log_sup, x_log_que,
                       label_sup)  # bx8

            # Prepare ground-truth similarity score and mask
            y_gt = label_que[:, :, 1]
            # The mask is 1 for the first num_query[b] positions of each row
            # and 0 elsewhere; the width of 10 is hard-coded.
            y_mask = np.zeros((batch_sz, 10), dtype=np.float32)
            for b in np.arange(batch_sz):
                y_mask[b, :num_query[b]] = 1
            y_mask = torch.FloatTensor(y_mask).cuda(GPU)

            # Calculate BCE loss (logits and targets are both multiplied by
            # the mask before the loss is taken)
            loss = F.binary_cross_entropy_with_logits(input=y_hat * y_mask,
                                                      target=y_gt.cuda(GPU) *
                                                      y_mask)
            total_trloss += loss.item()

            # Update Nets
            FeatEnc.zero_grad()
            RN.zero_grad()

            loss.backward()
            #torch.nn.utils.clip_grad_norm_(FeatEnc.parameters(), 0.5)
            #torch.nn.utils.clip_grad_norm_(RN.parameters(), 0.5)

            FeatEnc_optim.step()
            RN_optim.step()

            # Decision: threshold the sigmoid output at 0.5, masked positions
            # become 0
            y_prob = (torch.sigmoid(y_hat) * y_mask).detach().cpu().numpy()
            y_pred = ((torch.sigmoid(y_hat) > 0.5).float() *
                      y_mask).detach().cpu().long().numpy()

            # Prepare display (first sample of the batch only)
            sample_sup = label_sup[0, :num_support[0],
                                   1].detach().long().cpu().numpy().flatten()
            sample_que = label_que[0, :num_query[0],
                                   1].long().numpy().flatten()
            sample_pred = y_pred[0, :num_query[0]].flatten()
            sample_prob = y_prob[0, :num_query[0]].flatten()

            # Acc: count matches at unmasked positions only
            total_corrects += np.sum(
                (y_pred == label_que[:, :, 1].long().numpy()) *
                y_mask.cpu().numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, x_feat_sup, x_feat_que, y_hat

            # Report and reset the running statistics every 900 sessions.
            if (session + 1) % 900 == 0:
                hist_trloss.append(total_trloss / 900)
                hist_tracc.append(total_corrects / total_query)
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))

                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            # Validate and checkpoint every 4000 sessions.
            if (session + 1) % 4000 == 0:
                # Validation
                validate(mval_loader, FeatEnc, RN, eval_mode=True)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'FE_state': FeatEnc.state_dict(),
                        'RN_state': RN.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'FE_opt_state': FeatEnc_optim.state_dict(),
                        'RN_opt_state': RN_optim.state_dict(),
                        'FE_sch_state': FeatEnc_scheduler.state_dict(),
                        'RN_sch_state': RN_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))

        # End of epoch: validate and checkpoint once more. `session` holds
        # the last index of the inner loop at this point.
        # Validation
        validate(mval_loader, FeatEnc, RN, eval_mode=True)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'FE_state': FeatEnc.state_dict(),
                'RN_state': RN.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'FE_opt_state': FeatEnc_optim.state_dict(),
                'RN_opt_state': RN_optim.state_dict(),
                'FE_sch_state': FeatEnc_scheduler.state_dict(),
                'RN_sch_state': RN_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
示例#22
0
def main(args: argparse.Namespace):
    """Train, test or analyse a cross-domain re-identification model.

    Depending on ``args.phase`` this either visualises results
    (``'analysis'``), evaluates the best checkpoint on the target domain
    (``'test'``), or trains: source class centroids initialise the classifier
    head, pseudo labels for the target domain are regenerated every epoch via
    ``run_dbscan``, and the checkpoint with the best target mAP is copied to
    the ``'best'`` checkpoint path.

    Args:
        args: Parsed command-line options. ``n_classes``, ``n_s_classes`` and
            ``n_t_classes`` are written onto it; ``start_epoch`` is
            overwritten when resuming from ``args.resume``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # Dataset roots are resolved relative to this file's directory.
    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train,
                                           args.num_instances)
    train_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=train_transform),
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     sampler=sampler,
                                     pin_memory=True,
                                     drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    # Same source training images, but with the validation transform and in
    # fixed order; used below to extract features for the class centroids.
    cluster_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_target_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    # The classifier gets one output per source identity plus one per target
    # training sample.
    n_s_classes = source_dataset.num_train_pids
    args.n_classes = n_s_classes + len(target_dataset.train)
    args.n_s_classes = n_s_classes
    args.n_t_classes = len(target_dataset.train)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         args.n_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer)
    features_dim = model.features_dim

    idm_bn_names = filter_layers(args.stage)
    convert_dsbn_idm(model, idm_bn_names, idm=False)

    model = model.to(device)
    model = DataParallel(model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # create XBM
    dataset_size = len(source_dataset.train) + len(target_dataset.train)
    memory_size = int(args.ratio * dataset_size)
    xbm = XBM(memory_size, features_dim)

    # initialize source-domain class centroids
    source_feature_dict = extract_reid_feature(cluster_source_loader,
                                               model,
                                               device,
                                               normalize=True)
    source_features_per_id = {}
    for f, pid, _ in source_dataset.train:
        if pid not in source_features_per_id:
            source_features_per_id[pid] = []
        source_features_per_id[pid].append(source_feature_dict[f].unsqueeze(0))
    # One mean feature per identity, ordered by sorted pid so that row k of
    # the head weight corresponds to the k-th smallest pid.
    source_centers = [
        torch.cat(source_features_per_id[pid], 0).mean(0)
        for pid in sorted(source_features_per_id.keys())
    ]
    source_centers = torch.stack(source_centers, 0)
    source_centers = F.normalize(source_centers, dim=1)
    model.module.head.weight.data[0:n_s_classes].copy_(
        source_centers.to(device))

    # save memory
    del source_centers, cluster_source_loader, source_features_per_id

    # define optimizer and lr scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                 rate=args.rate),
                     args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.1)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    # NOTE(review): the best mAP is not stored in checkpoints, so it restarts
    # from 0 after a resume.
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        train_target_iter = run_dbscan(cluster_target_loader, model,
                                       target_dataset, train_transform, args)

        # train for one epoch
        # Report the learning rates actually held by the optimizer.
        # `lr_scheduler.get_lr()` must not be called outside `step()`: on
        # recent PyTorch it warns and returns a value with one extra `gamma`
        # factor on decay-boundary epochs.
        print([group['lr'] for group in optimizer.param_groups])
        train(train_source_iter, train_target_iter, model, optimizer, xbm,
              epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

        # update lr
        lr_scheduler.step()

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
示例#23
0
class Trainer():
    """Training/validation driver for a crowd-counting network.

    The network, Adam optimizer and StepLR scheduler are configured from the
    global ``cfg``. :meth:`forward` runs the epoch loop and dispatches to the
    validation routine that matches ``cfg.DATASET``.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Create the model, optimizer, scheduler, data loaders and logger.

        Args:
            dataloader: Callable returning ``(train_loader, val_loader,
                restore_transform)``.
            cfg_data: Dataset configuration object; ``LOG_PARA`` is read
                from it when counts are computed.
            pwd: Forwarded unchanged to ``logger``.
        """

        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET
        self.net = CrowdCounter(cfg.GPU_ID,self.net_name).cuda()
        # Only the CCN sub-module's parameters are optimized.
        self.optimizer = optim.Adam(self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        # self.optimizer = optim.SGD(self.net.parameters(), cfg.LR, momentum=0.95,weight_decay=5e-4)
        self.scheduler = StepLR(self.optimizer, step_size=cfg.NUM_EPOCH_LR_DECAY, gamma=cfg.LR_DECAY)

        # Best-so-far bookkeeping; handed to and replaced by update_model().
        self.train_record = {'best_mae': 1e20, 'best_mse':1e20, 'best_model_name': ''}
        self.timer = {'iter time' : Timer(),'train time' : Timer(),'val time' : Timer()}

        self.epoch = 0
        self.i_tb = 0

        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        self.train_loader, self.val_loader, self.restore_transform = dataloader()

        if cfg.RESUME:
            # Restore everything stored in the checkpoint and continue with
            # the epoch after the saved one.
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)


    def forward(self):
        """Run epochs from ``self.epoch`` up to ``cfg.MAX_EPOCH``.

        Every epoch trains once; validation runs when the epoch is a multiple
        of ``cfg.VAL_FREQ`` or lies beyond ``cfg.VAL_DENSE_START``, using the
        routine selected by ``self.data_mode``.
        """

        # self.validate_V3()
        for epoch in range(self.epoch,cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            # Single-argument print(...) is valid on Python 2 and Python 3.
            print('train time: {:.2f}s'.format(self.timer['train time'].diff))
            print('='*20)

            # validation
            if epoch%cfg.VAL_FREQ==0 or epoch>cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                # Dataset names are compared by value (==); an identity test
                # (`is`) against a string literal only works by accident of
                # interning.
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                    self.validate_V1()
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print('val time: {:.2f}s'.format(self.timer['val time'].diff))


    def train(self): # training for all datasets
        """Train for one pass over ``self.train_loader``.

        Every ``cfg.PRINT_FREQ`` iterations the loss is written to the
        summary writer and a status line is printed.
        """
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            self.timer['iter time'].tic()
            img, gt_map = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map)
            loss = self.net.loss  # the network exposes its loss as an attribute
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)
                # The learning rate is printed multiplied by 10000.
                print('[ep %d][it %d][loss %.4f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), self.optimizer.param_groups[0]['lr']*10000, self.timer['iter time'].diff))
                # Counts of the first sample in the batch only.
                print('        [cnt: gt: %.1f pred: %.2f]' % (gt_map[0].sum().data/self.cfg_data.LOG_PARA, pred_map[0].sum().data/self.cfg_data.LOG_PARA))


    def validate_V1(self):# validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate with a single loader and global MAE / MSE / loss meters.

        The first validation batch is passed to ``vis_results``; the averaged
        metrics are logged and handed to ``update_model``.
        """

        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img,gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):

                    pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                    # The batch loss is recorded once per image so it carries
                    # the same weight as the per-image errors.
                    losses.update(self.net.loss.item())
                    maes.update(abs(gt_count-pred_cnt))
                    mses.update((gt_count-pred_cnt)*(gt_count-pred_cnt))
                if vi==0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)  # reported "mse" is the root of the mean squared error
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)
        print_summary(self.exp_name,[mae, mse, loss],self.train_record)


    def validate_V2(self):# validate_V2 for WE
        """Validate on WE, where ``self.val_loader`` yields one loader per scene.

        Loss and MAE are tracked per scene in five-category meters and then
        averaged across scenes. No MSE is computed for this dataset, so ``0``
        is reported in its place.
        """

        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        roi_mask = []
        from datasets.WE.setting import cfg_data
        from scipy import io as sio
        for val_folder in cfg_data.VAL_FOLDER:

            roi_mask.append(sio.loadmat(os.path.join(cfg_data.DATA_PATH,'test',val_folder + '_roi.mat'))['BW'])

        for i_sub,i_loader in enumerate(self.val_loader,0):

            # NOTE(review): the ROI mask is loaded and selected but never
            # applied to the maps below — confirm whether masking is done
            # elsewhere.
            mask = roi_mask[i_sub]
            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    pred_map = self.net.forward(img,gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    for i_img in range(pred_map.shape[0]):

                        pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                        gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                        losses.update(self.net.loss.item(),i_sub)
                        maes.update(abs(gt_count-pred_cnt),i_sub)
                    if vi==0:
                        vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mae_s1', maes.avg[0], self.epoch + 1)
        self.writer.add_scalar('mae_s2', maes.avg[1], self.epoch + 1)
        self.writer.add_scalar('mae_s3', maes.avg[2], self.epoch + 1)
        self.writer.add_scalar('mae_s4', maes.avg[3], self.epoch + 1)
        self.writer.add_scalar('mae_s5', maes.avg[4], self.epoch + 1)

        # `mse` does not exist in this method (referencing it raised
        # NameError); pass 0, the same placeholder given to print_WE_summary.
        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, 0, loss],self.train_record,self.log_txt)
        print_WE_summary(self.log_txt,self.epoch,[mae, 0, loss],self.train_record,maes)


    def validate_V3(self):# validate_V3 for GCC
        """Validate on GCC with additional per-attribute error meters.

        Besides the global MAE / MSE / loss, errors are bucketed by the three
        attribute columns read from each sample: ``level`` (9 categories),
        ``time`` (8 categories, attribute value divided by 3) and ``weather``
        (7 categories).
        """

        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {'level':AverageCategoryMeter(9), 'time':AverageCategoryMeter(8),'weather':AverageCategoryMeter(7)}
        c_mses = {'level':AverageCategoryMeter(9), 'time':AverageCategoryMeter(8),'weather':AverageCategoryMeter(7)}


        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()


                pred_map = self.net.forward(img,gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                # squeeze() is idempotent, so it is applied once per batch
                # rather than once per image.
                # NOTE(review): squeeze() also drops a batch dimension of
                # size 1 — confirm the attribute tensor's shape for
                # single-image batches.
                attributes_pt = attributes_pt.squeeze()

                for i_img in range(pred_map.shape[0]):

                    pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count-pred_cnt)
                    s_mse = (gt_count-pred_cnt)*(gt_count-pred_cnt)

                    losses.update(self.net.loss.item())
                    maes.update(s_mae)
                    mses.update(s_mse)
                    c_maes['level'].update(s_mae,attributes_pt[i_img][0])
                    c_mses['level'].update(s_mse,attributes_pt[i_img][0])
                    c_maes['time'].update(s_mae,attributes_pt[i_img][1]/3)
                    c_mses['time'].update(s_mse,attributes_pt[i_img][1]/3)
                    c_maes['weather'].update(s_mae,attributes_pt[i_img][2])
                    c_mses['weather'].update(s_mse,attributes_pt[i_img][2])


                if vi==0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)  # reported "mse" is the root of the mean squared error


        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)


        print_GCC_summary(self.log_txt,self.epoch,[mae, mse, loss],self.train_record,c_maes,c_mses)
示例#24
0
def main(args):
    """Train a ``ResNetClassifier`` on the selected dataset, evaluating every epoch.

    When ``args.eval`` is set, a single evaluation on the test data is run and
    the process exits. Otherwise each epoch evaluates on test and train data,
    logs the result (wandb + CSV + plot), trains for one epoch, steps the LR
    scheduler and writes ``ckpt_best.pt`` / ``last_ckpt.pt`` (and periodic
    ``ckpt_<epoch>.pt``) into ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``no_cuda``,
            ``seed``, ``dataset``, ``dataroot``, ``data_augmentation``,
            ``ckpt_path``, ``eval``, ``n_eval_batches``, ``output_dir``,
            ``resume``, ``optim``, ``lr``, ``db``, ``epochs`` and
            ``ckpt_every``.

    Raises:
        ValueError: If ``args.dataset`` or ``args.optim`` is not one of the
            names handled below.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    set_random_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # -- Dataset and transform selection
    if args.dataset == 'mnist':
        train_data = get_dataset('mnist-train', args.dataroot)
        test_data = get_dataset('mnist-test', args.dataroot)
        train_tr = test_tr = get_transform('mnist_normalize')
    elif args.dataset == 'cifar10':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar10-train', args.dataroot)
        test_data = get_dataset('cifar10-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')
    elif args.dataset == 'cifar-fs-train':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar-fs-train-train', args.dataroot)
        test_data = get_dataset('cifar-fs-train-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')
    elif args.dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args.dataroot)
        test_data = get_dataset('miniimagenet-train-test', args.dataroot)
        train_tr = get_transform('cifar_augment_normalize_84' if args.data_augmentation else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')
    else:
        # Previously an unknown name fell through and failed later with an
        # unbound-variable error on `train_data`.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    model = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
    if args.ckpt_path != '':
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        loaded = torch.load(args.ckpt_path, map_location=device)
        model.load_state_dict(loaded)
    if args.eval:
        # Same call signature as the per-epoch evaluation in the loop below.
        acc = test(args, model, device, test_data, test_tr, args.n_eval_batches)
        print("Eval Acc: ", acc)
        sys.exit()

    # Trace logging
    mkdir(args.output_dir)
    eval_fieldnames = ['global_iteration', 'val_acc', 'train_acc']
    eval_logger = CSVLogger(every=1,
                            fieldnames=eval_fieldnames,
                            resume=args.resume,
                            filename=os.path.join(args.output_dir, 'eval_log.csv'))
    wandb.run.name = os.path.basename(args.output_dir)
    wandb.run.save()
    wandb.watch(model)

    # -- Optimizer and LR schedule
    if args.optim == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optim!r}")
    if args.dataset == 'mnist':
        scheduler = StepLR(optimizer, step_size=1, gamma=.7)
    else:
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    start_epoch = 1
    if args.resume:
        last_ckpt_path = os.path.join(args.output_dir, 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            loaded = torch.load(last_ckpt_path, map_location=device)
            model.load_state_dict(loaded['model_sd'])
            optimizer.load_state_dict(loaded['optimizer_sd'])
            scheduler.load_state_dict(loaded['scheduler_sd'])
            start_epoch = loaded['epoch']

    # It's important to set seed again before training b/c dataloading code
    # might have reset the seed.
    set_random_seed(args.seed)
    best_val = 0
    if args.db:
        # Debug mode: short run with a fast-decaying schedule.
        scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4], gamma=0.1)
        args.epochs = 5
    for epoch in range(start_epoch, args.epochs + 1):
        if epoch % args.ckpt_every == 0:
            torch.save(model.state_dict(), os.path.join(args.output_dir, f"ckpt_{epoch}.pt"))

        # Evaluate the current weights before this epoch's training step.
        stats_dict = {'global_iteration': epoch}
        val = stats_dict['val_acc'] = test(args, model, device, test_data, test_tr, args.n_eval_batches)
        stats_dict['train_acc'] = test(args, model, device, train_data, test_tr, args.n_eval_batches)

        # Save the best checkpoint now, while the model still holds the
        # weights that produced `val`; saving after train() would store
        # different weights under the "best" name.
        if val > best_val:
            best_val = val
            torch.save(model.state_dict(), os.path.join(args.output_dir, "ckpt_best.pt"))

        grid = make_grid(torch.stack([train_tr(x) for x in train_data['x'][:30]]), nrow=6).permute(1, 2, 0).numpy()
        img_dict = {"examples": [wandb.Image(grid, caption="Data batch")]}
        wandb.log(stats_dict)
        wandb.log(img_dict)
        eval_logger.writerow(stats_dict)
        plot_csv(eval_logger.filename, os.path.join(args.output_dir, 'iteration_plots.png'))

        train(args, model, device, train_data, train_tr, optimizer, epoch)

        # NOTE(review): passing `epoch` to step() is deprecated in recent
        # PyTorch releases; kept so the schedule position stays tied to the
        # epoch counter exactly as before.
        scheduler.step(epoch)

        # For `resume`: the model weights are stored on CPU, and the stored
        # epoch is the next one to run.
        model.cpu()
        torch.save({
            'model_sd': model.state_dict(),
            'optimizer_sd': optimizer.state_dict(),
            'scheduler_sd': scheduler.state_dict(),
            'epoch': epoch + 1
            }, os.path.join(args.output_dir, "last_ckpt.pt"))
        model.to(device)
示例#25
0
class BDDVAgent(LearningAgent):  # ++ Extend Learning agent
    """Learning agent that predicts a binned steering distribution and a speed.

    The wrapped model returns an intermediate representation, a speed
    prediction and an activation map. One of several branch heads
    (``model.get_branches``) maps the intermediate representation to the
    per-bin steering scores; which head is used for a sample is given by the
    batch mask during training/testing and by the command at inference time.
    """

    def __init__(self, cfg):
        """Build model, optimizer, scheduler and the demo output directories.

        :param cfg: configuration namespace; reads ``cfg.model``,
            ``cfg.data_info`` and ``cfg.train`` here.
        """
        super(BDDVAgent, self).__init__(cfg)

        use_cuda = self._use_cuda  # ++ Parent class already saves some configuration variables
        # ++ All parent variables should start with _.

        # -- Get necessary variables from cfg
        self.cfg = cfg

        # -- Initialize model
        model_class = get_models(cfg.model)

        # NOTE(review): `input_shape` aliases cfg.data_info.image_shape, so the
        # multiplication below also changes the value stored in cfg. Kept as-is
        # because other code may read the updated shape -- confirm.
        input_shape = cfg.data_info.image_shape
        input_shape[0] *= cfg.data_info.frame_seq_len
        self.model = model_class[0](cfg, input_shape, cfg.model.nr_bins)
        # ++ All models receive as parameters (configuration namespace, input data size,
        # ++ output data size)

        self._models.append(
            self.model)  # -- Add models & optimizers to base for saving

        # ++ After adding model you can set the agent to cuda mode
        # ++ Parent class already makes some adjustments. E.g. turns model to cuda mode
        if use_cuda:
            self.cuda()

        # Left edges of nr_bins equal-width bins covering [-1, 1).
        self._bins = np.arange(-1.0, 1.0, 2.0 / cfg.model.nr_bins)
        # -- Initialize optimizers
        self.optimizer = self.get_optim(cfg.train.algorithm,
                                        cfg.train.algorithm_args, self.model)
        self.scheduler = StepLR(self.optimizer, cfg.train.step_size,
                                cfg.train.decay)
        self._optimizers.append(
            self.optimizer)  # -- Add models & optimizers to base for saving
        # -- Change settings from parent class
        # ++ Parent class automatically initializes 4 metrics: loss/acc for train/test
        # ++ E.g switch metric slope
        self.set_eval_metric_comparison(True)

        # ++ E.g. to add variable name to be saved at checkpoints
        self._save_data.append("scheduler")

        self._tensorboard_model = False
        self.loss_values_train = []
        self.loss_values_test = []

        # -- Output directories used by the demo, created under the current
        # -- working directory.
        self.img_dir = os.getcwd() + "/" + image_dir
        self.act_dir = os.getcwd() + "/" + activations_dir
        self.steer_dir = os.getcwd() + "/" + steer_distr_dir

        for demo_dir in (self.img_dir, self.act_dir, self.steer_dir):
            if not os.path.isdir(demo_dir):
                os.mkdir(demo_dir)
        self.nr_img = 0

        super(BDDVAgent, self).__end_init__()

    def _session_init(self):
        """Clear optimizer gradients at session start when in training mode."""
        if self._is_train:
            self.optimizer.zero_grad()

    def _apply_branches(self, branches, inter_output, mask):
        """Route every sample through the branch selected by its mask entry.

        :param branches: indexable collection of branch heads
        :param inter_output: intermediate model output, one row per sample
        :param mask: per-sample branch indices as yielded by the data loader
        :return: tensor of shape (batch, nr_bins) with the branch outputs
        """
        output = to_cuda(
            torch.zeros((mask.shape[0], self.cfg.model.nr_bins)),
            self._use_cuda)

        # Reshape mask to use it for selecting frames at each moment
        mask = mask.reshape((-1, mask.shape[0]))

        for i in range(0, len(branches)):
            # Hardcode for non-temporal case for now
            filter_ = (mask[0] == i)
            # Skip branches that no sample in this batch selects.
            if filter_.any():
                output[filter_] = branches[i](inter_output[filter_])

        return output

    def _train(self, data_loader):
        """
        Considering a dataloader (loaded from config.)
        Implement the training loop.
        :return training loss metric & other information
        """
        optimizer = self.optimizer
        scheduler = self.scheduler
        use_cuda = self._use_cuda
        model = self.model
        criterion = self._get_criterion
        branches = self.model.get_branches(use_cuda)
        train_loss = 0

        progress_bar = ProgressBar(
            'Loss: %(loss).3f', dict(loss=0), len(data_loader))

        for batch_idx, (images, speed, steer_distr, mask) in enumerate(data_loader):
            optimizer.zero_grad()
            images = to_cuda(images, use_cuda)
            speed_target = to_cuda(speed, use_cuda)
            steer_distr = to_cuda(steer_distr, use_cuda)
            inter_output, speed_output, _ = model(images, speed_target)

            output = self._apply_branches(branches, inter_output, mask)

            loss = criterion(output, speed_output, speed_target, steer_distr)

            loss.backward()
            batch_loss = loss.item()
            train_loss += batch_loss

            optimizer.step()
            # The scheduler is stepped once per batch, so cfg.train.step_size
            # counts batches rather than epochs.
            scheduler.step()
            progress_bar.update(
                batch_idx, dict(loss=(train_loss / (batch_idx + 1))))

            self.loss_values_train.append(batch_loss)

            # -- TensorBoard logging (currently disabled)
            # self._writer.add_scalar(
            #     "loss_function", loss.item(),
            #     batch_idx + self._train_epoch * len(data_loader))
            # if self._tensorboard_model is False:
            #     self._tensorboard_model = True
            #     self._writer.add_graph(model, (images, speed_target))

        progress_bar.finish()

        return train_loss, {}

    def _get_criterion(self, branch_outputs, speed_outputs,
                       speed_target, steer_distr):
        """Weighted sum of squared errors, averaged over the batch.

        Steering (branch output vs. target distribution) contributes 0.95 and
        speed contributes 0.05; both terms are summed squared errors divided
        by the number of rows in ``branch_outputs``.
        """
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        loss1 = torch.nn.functional.mse_loss(
            branch_outputs, steer_distr, reduction='sum')

        speed_error = speed_outputs - speed_target
        loss2 = (speed_error * speed_error).sum()

        loss = (0.95 * loss1 + 0.05 * loss2) / branch_outputs.shape[0]
        return loss

    def _test(self, data_loader):
        """
        Considering a dataloader (loaded from config.)
        Implement the testing loop.
        :return accumulated test loss, None, and an empty info dict
        """
        use_cuda = self._use_cuda
        model = self.model
        criterion = self._get_criterion
        branches = self.model.get_branches(use_cuda)
        test_loss = 0

        progress_bar = ProgressBar(
            'Loss: %(loss).3f', dict(loss=0), len(data_loader))

        for batch_idx, (images, speed, steer_distr, mask) in enumerate(data_loader):
            images = to_cuda(images, use_cuda)
            speed_target = to_cuda(speed, use_cuda)
            steer_distr = to_cuda(steer_distr, use_cuda)

            inter_output, speed_output, _ = model(images, speed_target)

            output = self._apply_branches(branches, inter_output, mask)

            loss = criterion(output, speed_output, speed_target, steer_distr)

            batch_loss = loss.item()
            test_loss += batch_loss

            self.loss_values_test.append(batch_loss)

            progress_bar.update(
                batch_idx, dict(loss=(test_loss / (batch_idx + 1))))

        progress_bar.finish()

        return test_loss, None, {}

    def _get_steer_from_bins(self, steer_vector):
        """Plot the steering distribution and return the most likely steer value.

        :param steer_vector: 2-D tensor of per-bin scores (first row is plotted)
        :return: value of the arg-max bin, offset by ``1 / nr_bins``
        """
        # Pass the steer values through softmax and get the bin index.
        # dim=1 matches what the implicit-dim call chose for 2-D input.
        probabilities = torch.nn.functional.softmax(steer_vector, dim=1)
        bin_index = int(probabilities.argmax())

        bin_values = self._bins + 1.0 / len(self._bins)
        # .cpu() so this also works when the model runs on the GPU.
        plt.plot(bin_values, probabilities.data[0].cpu().numpy())
        plt.show(block=False)

        plt.draw()
        plt.pause(0.0001)
        #plt.savefig(self.steer_dir + "/distr_" + str(self.nr_img) + ".png")
        plt.gcf().clear()
        # get steer_value from bin
        return bin_values[bin_index]

    def _show_activation_image(self, raw_activation, image_activation):
        """Display the input image next to its colour-mapped activation map.

        :param raw_activation: activation tensor; element [0, 0] is shown
        :param image_activation: H x W x 3 uint8 image, modified in place
        """
        activation_map = raw_activation.data[0, 0].cpu().numpy()

        # Min-max normalise to [0, 1]; a constant map becomes all zeros
        # instead of dividing by zero.
        lowest = np.min(activation_map)
        value_range = np.max(activation_map) - lowest
        if value_range > 0:
            activation_map = (activation_map - lowest) / value_range
        else:
            activation_map = np.zeros_like(activation_map)

        activation_map = (activation_map * 255.0)

        if image_activation.shape[0] != activation_map.shape[0]:
            # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3; this
            # call needs an older SciPy (with Pillow) or a replacement.
            activation_map = scipy.misc.imresize(
                activation_map,
                [image_activation.shape[0], image_activation.shape[1]])

        # NOTE(review): uint8 addition wraps around on overflow -- confirm
        # that saturation is not what was intended here.
        image_activation[:, :, 1] += activation_map.astype(np.uint8)
        activation_map = cv2.applyColorMap(
            activation_map.astype(np.uint8), cv2.COLORMAP_JET)

        # The third positional argument of cv2.resize is `dst`, so the
        # interpolation flag has to be passed by keyword to take effect.
        image_activation = cv2.resize(image_activation, (720, 460),
                                      interpolation=cv2.INTER_AREA)
        image_activation = cv2.cvtColor(image_activation, cv2.COLOR_RGB2BGR)

        activation_map = cv2.resize(activation_map, (720, 460),
                                    interpolation=cv2.INTER_AREA)

        cv2.imshow("activation",
                   np.concatenate((image_activation, activation_map), axis=1))

        # waitKey gives the HighGUI window a chance to redraw.
        cv2.waitKey(1)

    def _prepare_inputs(self, image_raw, speed):
        """Convert a raw H x W x C image and a speed into batched model inputs."""
        image = np.transpose(image_raw, (2, 0, 1)).astype(np.float32)
        # Scale pixel values from [0, 255] to [-1, 1].
        image = np.multiply(image, 1.0 / 127.5) - 1
        image = to_cuda(torch.from_numpy(image), self._use_cuda)
        image = image.unsqueeze(0)
        speed = to_cuda(torch.Tensor([speed / 90.0]), self._use_cuda)
        speed = speed.unsqueeze(0)
        return image, speed

    def run_image(self, image_raw, speed, cmd):
        """Switch to eval mode and predict steer, speed and activation map.

        :return: (steer value, predicted speed * 90, activation map)
        """
        self.set_eval_mode()

        image, speed = self._prepare_inputs(image_raw, speed)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, speed_output, activation_map = self.model(image, speed)

        output = branches[cmd](inter_output)

        steer_angle = self._get_steer_from_bins(output)
        speed_output = speed_output.data.cpu()[0].numpy()
        return steer_angle, speed_output[0] * 90, activation_map

    def run_1step(self, image_raw, speed, cmd):
        """Predict steer and speed for one frame (no mode switch).

        Shows the activation overlay when ``cfg.activations`` is set.
        :return: (steer value, predicted speed * 90)
        """
        image, speed = self._prepare_inputs(image_raw, speed)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, speed_output, activation_map = self.model(image, speed)

        if self.cfg.activations:
            self._show_activation_image(activation_map, np.copy(image_raw))

        output = branches[cmd](inter_output)

        steer_angle = self._get_steer_from_bins(output)
        speed_output = speed_output.data.cpu()[0].numpy()
        return steer_angle, speed_output[0] * 90

    def _eval_episode(self, file_name):
        """Evaluate one recorded episode frame by frame.

        :param file_name: pair (video file path, info CSV path)
        :return: (combined mse, steering mse) averaged over the frames, or
            (None, None) when a frame could not be read from the video
        """
        video_file = file_name[0]
        info_file = file_name[1]
        info = pd.read_csv(info_file)

        nr_images = len(info)
        previous_speed = info['linear_speed'][0]

        general_mse = steer_mse = 0

        # Determine steering angles and commands
        helper = DatasetHelper(None, None, None, self.cfg.dataset)
        frame_indices = range(len(info))
        course = info['course']
        linear_speed = info['linear_speed']
        angles, cmds = helper.get_steer(frame_indices, course, linear_speed)

        # Open video to read frames; released in `finally` so the early
        # return below does not leak the capture handle.
        vid = cv2.VideoCapture(video_file)
        try:
            for index in range(nr_images):
                ret, frame = vid.read()
                if not ret:
                    print('Could not retrieve frame')
                    return None, None

                gt_speed = linear_speed[index]
                gt_steer = angles[index]

                predicted_steer, predicted_speed = self.run_1step(
                    frame, previous_speed, cmds[index])

                steer = (predicted_steer - gt_steer) * (predicted_steer - gt_steer)
                speed = (predicted_speed - gt_speed) * (predicted_speed - gt_speed)
                steer_mse += steer

                general_mse += 0.05 * speed + 0.95 * steer

                log.info("Frame number {}".format(index))
                log.info("Steer: predicted {}, ground_truth {}".format(
                    predicted_steer, gt_steer))

                log.info("Speed: predicted {}, ground_truth {}".format(
                    predicted_speed, gt_speed))

                # The next prediction is fed the ground-truth speed.
                previous_speed = gt_speed
        finally:
            vid.release()

        general_mse /= float(nr_images)
        steer_mse /= float(nr_images)

        return general_mse, steer_mse

    def eval_agent(self):
        """Evaluate every test video and write ``eval_results.txt``.

        Per-episode errors are written as they are computed, followed by the
        mean and (population) standard deviation over all episodes that could
        be evaluated. Episodes whose video could not be read are skipped.
        """
        self.set_eval_mode()

        data_files = sorted(os.listdir(self.cfg.dataset.dataset_test_path))
        video_files = []
        for data_file in data_files:
            info_file = data_file.split('.')[0] + '.csv'
            video_files.append((os.path.join(self.cfg.dataset.dataset_test_path, data_file),
                os.path.join(self.cfg.dataset.info_test_path, info_file)))
        eval_results = []

        with open(self._save_path + "/eval_results.txt", "wt") as f:
            for video_file in video_files:
                general_mse, steer_mse = self._eval_episode(video_file)
                if general_mse is None:
                    # _eval_episode signals an unreadable video with None.
                    f.write("Skipped {}: could not retrieve all frames\n\n".format(
                        video_file))
                    f.flush()
                    continue
                eval_results.append((general_mse, steer_mse))

                f.write(
                    "****************Evaluated {} *******************\n".format(
                        video_file))
                f.write("Mean squared error is {}\n".format(str(general_mse)))
                f.write("Mean squared error for steering is {}\n".format(
                    str(steer_mse)))
                f.write("************************************************\n\n")
                f.flush()

            if not eval_results:
                f.write("No episodes could be evaluated\n")
                return

            nr_episodes = float(len(eval_results))
            mean_mse = sum(result[0] for result in eval_results) / nr_episodes
            mean_steer = sum(result[1] for result in eval_results) / nr_episodes

            # Each result is (general_mse, steer_mse): steering is index 1.
            std_mse = math.sqrt(
                sum((result[0] - mean_mse) ** 2 for result in eval_results)
                / nr_episodes)
            std_steer = math.sqrt(
                sum((result[1] - mean_steer) ** 2 for result in eval_results)
                / nr_episodes)

            f.write("****************Final Evaluation *******************\n")
            f.write("Mean squared error is {} with standard deviation {}\n".format(
                str(mean_mse), str(std_mse)))
            # Report the mean over episodes, not the last episode's value.
            f.write(
                "Mean squared error for steering is {} with standard deviation {}\n".
                format(str(mean_steer), str(std_steer)))
            f.write("******************************************************")
            f.flush()

    def _control_function(self, image_input_raw, real_speed, control_input):
        """
        Implement for carla simulator run.
        :param image_input_raw: raw H x W x C camera image
        :param real_speed: current speed; divided by 25.0 before use
        :param control_input: command code selecting the branch head
        :return: steer, acc, brake
        """
        print("Control input is {}".format(control_input))

        # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3.
        image_input = scipy.misc.imresize(image_input_raw, [
            self.cfg.data_info.image_shape[1],
            self.cfg.data_info.image_shape[2]
        ])
        image_input = np.transpose(image_input, (2, 0, 1)).astype(np.float32)
        image_input = np.multiply(image_input, 1.0 / 127.5) - 1.0
        # Move inputs to the model's device, as run_1step/run_image do.
        image_input = to_cuda(torch.from_numpy(image_input), self._use_cuda)
        image_input = image_input.unsqueeze(0)
        speed = to_cuda(torch.Tensor([real_speed / 25.0]), self._use_cuda)
        speed = speed.unsqueeze(0)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, predicted_speed, activation_map = self.model(
            image_input, speed)

        if self.cfg.activations:
            self._show_activation_image(activation_map,
                                        np.copy(image_input_raw))

        # Map the command code to a branch head.
        if control_input == 2 or control_input == 0:
            output = branches[1](inter_output)
        elif control_input == 3:
            output = branches[2](inter_output)
        elif control_input == 4:
            output = branches[3](inter_output)
        else:
            output = branches[4](inter_output)

        # The last two outputs are acceleration and brake; the rest are the
        # steering bins.
        steer = self._get_steer_from_bins(output[:, :-2])
        output = output.data.cpu()[0].numpy()
        acc, brake = output[-2], output[-1]

        predicted_speed = predicted_speed.data.cpu()[0].numpy()
        real_predicted = predicted_speed * 25.0

        # Nearly stopped while the model predicts motion: add extra
        # acceleration and release the brake.
        if real_speed < 2.0 and real_predicted > 3.0:
            acc = 1 * (5.6 / 25.0 - real_speed / 25.0) + acc
            brake = 0.0

        self.nr_img += 1
        return steer, acc, brake

    def _set_eval_mode(self):
        """
        Custom configuration when changing to evaluation mode
        """
        if self.cfg.activations:
            self.model.set_forward('forward_deconv')
        else:
            self.model.set_forward('forward_simple')
        if self._use_cuda:
            self.cuda()

    def _set_train_mode(self):
        """
        Custom configuration when changing to train mode
        """
        self.model.set_forward('forward_simple')
        if self._use_cuda:
            self.cuda()

    def _save(self, save_data, path):
        """
        Called when saving agent state. Agent already saves variables defined in the list
        self._save_data and other default options.
        :param save_data: Pre-loaded dictionary with saved data. Append here other data
        :param path: Path to folder where other custom data can be saved
        :return: should return default save_data dictionary to be saved
        """
        save_data['scheduler_state'] = self.scheduler.state_dict()
        save_data['train_epoch'] = self._train_epoch
        save_data['loss_value_train'] = self.loss_values_train
        save_data['loss_value_test'] = self.loss_values_test

        return save_data

    def _resume(self, agent_check_point_path, saved_data):
        """
        Custom resume scripts should be implemented here
        :param agent_check_point_path: Path of the checkpoint resumed
        :param saved_data: loaded checkpoint data (dictionary of variables)
        """
        # Refresh the model/optimizer references first so the scheduler is
        # bound to the optimizer that is actually in use after resuming.
        self.model = self._models[0]
        self.optimizer = self._optimizers[0]
        self.scheduler.load_state_dict(saved_data['scheduler_state'])
        self.scheduler.optimizer = self.optimizer
        self._train_epoch = saved_data['train_epoch']
        self.loss_values_train = saved_data['loss_value_train']
        self.loss_values_test = saved_data['loss_value_test']
        if not self._use_cuda:
            self.model.cpu()
示例#26
0
def train(train_dir, model_dir, config_path, checkpoint_path,
          n_steps, save_every, test_every, decay_every,
          n_speakers, n_utterances, seg_len):
    """Train a d-vector network.

    Args:
        train_dir: Directory passed to ``SEDataset``.
        model_dir: Directory for TensorBoard logs and ``ckpt-<step>.tar`` files.
        config_path: Model configuration file loaded by ``DVector``.
        checkpoint_path: Checkpoint to resume from, or ``None``.
        n_steps: Number of optimisation steps to run in this call.
        save_every: Save a checkpoint every this many steps.
        test_every: Log a validation loss every this many steps.
        decay_every: ``StepLR`` step size (the scheduler steps once per batch).
        n_speakers: Speakers per batch (the DataLoader batch size).
        n_utterances: Utterances per speaker.
        seg_len: Segment length passed to ``SEDataset``.
    """

    # setup
    total_steps = 0

    # load data; 2 * n_speakers items are held out for validation
    dataset = SEDataset(train_dir, n_utterances, seg_len)
    train_set, valid_set = random_split(dataset, [len(dataset)-2*n_speakers,
                                                  2*n_speakers])
    train_loader = DataLoader(train_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    train_iter = iter(train_loader)

    assert len(train_set) >= n_speakers
    assert len(valid_set) >= n_speakers
    print(f"Training starts with {len(train_set)} speakers. "
          f"(and {len(valid_set)} speakers for validation)")

    # build network and training tools
    dvector = DVector().load_config_file(config_path)
    criterion = GE2ELoss()
    optimizer = SGD(list(dvector.parameters()) +
                    list(criterion.parameters()), lr=0.01)
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    # load checkpoint
    if checkpoint_path is not None:
        # Load onto CPU so a checkpoint written on a GPU machine can be
        # resumed anywhere; the modules are moved to `device` below.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        total_steps = ckpt["total_steps"]
        dvector.load_state_dict(ckpt["state_dict"])
        criterion.load_state_dict(ckpt["criterion"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dvector = dvector.to(device)
    criterion = criterion.to(device)
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    # start training
    for step in pbar:

        total_steps += 1

        # Restart the loader when it is exhausted so training is step-based.
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)

        loss = criterion(embd)

        optimizer.zero_grad()
        loss.backward()

        # Clip the global gradient norm, then scale down the gradients of the
        # embedding layer and of the loss's w/b parameters.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            list(dvector.parameters()) + list(criterion.parameters()), max_norm=3)
        dvector.embedding.weight.grad.data *= 0.5
        criterion.w.grad.data *= 0.01
        criterion.b.grad.data *= 0.01

        optimizer.step()
        scheduler.step()

        # Log plain floats so no graph-attached tensor is kept by the logger.
        train_loss = loss.item()
        pbar.set_description(f"global = {total_steps}, loss = {train_loss:.4f}")
        writer.add_scalar("Training loss", train_loss, total_steps)
        writer.add_scalar("Gradient norm", grad_norm, total_steps)

        if (step + 1) % test_every == 0:
            batch = next(iter(valid_loader))
            # No gradients are needed for the validation loss.
            with torch.no_grad():
                embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)
                valid_loss = criterion(embd)
            writer.add_scalar("validation loss", valid_loss.item(), total_steps)

        if (step + 1) % save_every == 0:
            ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
            ckpt_dict = {
                "total_steps": total_steps,
                "state_dict": dvector.state_dict(),
                "criterion": criterion.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(ckpt_dict, ckpt_path)

    print("Training completed.")
def main():
    """Test a pruned model on the validation queue; fine-tuning code follows.

    Configuration comes from the module-level ``args``. The names
    ``batch_size``, ``num_gpu``, ``num_workers`` and ``kd_flag`` are also
    read here but are not defined in this function.

    NOTE(review): the unconditional ``exit()`` after the first ``test`` call
    terminates the process, so the optimizer/scheduler setup, resume logic
    and training loop below it are currently unreachable -- confirm whether
    that is intended.
    """
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    #loader = import_module('data.' + args.dataset).Data(args)

    # while(1):
    #     a=1

    # Hard-coded ImageNet image-folder locations.
    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_train')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_val')
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=batch_sizes, shuffle=True,
    #     num_workers=8, pin_memory=True, sampler=None)

    # NOTE(review): val_loader is not referenced again in this function;
    # evaluation below uses valid_queue instead.
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    # traindir/valdir are rebound to the "_rec" directories used by the
    # queue-based loaders below.
    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_val_rec')


    # batch_size, num_gpu and num_workers come from module scope.
    train_queue = getTrainValDataset(traindir, valdir, batch_size=batch_size, val_batch_size=batch_size,
                                     num_shards=num_gpu, workers=num_workers)
    valid_queue = getTestDataset(valdir, test_batch_size=batch_size, num_shards=num_gpu,
                                 workers=num_workers)

    #loader = cifar100(args)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))


    # Either rebuild the sparse model from the stored mask, or prune from the
    # stored state dict.
    if args.pruned:
        mask = checkpoint['mask']
        model = resnet_56_sparse(has_mask = mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    # model = torchvision.models.resnet18()

    # Report FLOPs and parameter count for a 3x224x224 input.
    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
        print('Flops:  ' + flops)
        print('Params: ' + params)
    pruned_dir = args.pruned_dir
    checkpoint_pruned = torch.load(pruned_dir, map_location=torch.device(f"cuda:{args.gpus[0]}"))
    # The pruned checkpoint is loaded into the DataParallel-wrapped model.
    model = torch.nn.DataParallel(model)
    #
    # new_state_dict_pruned = OrderedDict()
    # for k, v in checkpoint_pruned.items():
    #     name = k[7:]
    #     new_state_dict_pruned[name] = v
    # model.load_state_dict(new_state_dict_pruned)

    model.load_state_dict(checkpoint_pruned['state_dict_s'])

    test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test)
    # '%e' renders these values in scientific notation.
    logging.info('Simply test after prune: %e ', test_prec1)
    logging.info('Model size: %e ', get_parameters_size(model)/1e6)

    # NOTE(review): unconditional exit -- nothing below this line runs,
    # including the args.test_only check.
    exit()

    if args.test_only:
        return
    # Optimise every parameter whose name does not contain 'mask'.
    param_s = [param for name, param in model.named_parameters() if 'mask' not in name]
    #optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.00001, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer = optim.SGD(param_s, lr=1e-5, momentum=args.momentum,weight_decay=args.weight_decay)
    # NOTE(review): this StepLR is replaced by the CosineAnnealingLR two
    # lines below before it is ever stepped.
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.num_epochs))


    # Optional teacher model; kd_flag comes from module scope.
    model_kd = None
    if kd_flag:
        model_kd = ResNet101()
        ckpt_kd = torch.load('resnet101.t7', map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_kd = ckpt_kd['net']
        new_state_dict_kd = OrderedDict()
        for k, v in state_dict_kd.items():
            # Drop the first 7 characters of each key (presumably the
            # "module." prefix added by DataParallel -- confirm).
            name = k[7:]
            new_state_dict_kd[name] = v
    #print(new_state_dict_kd)
        model_kd.load_state_dict(new_state_dict_kd)
        model_kd = model_kd.to(args.gpus[1])

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))
    #print(model.named_parameters())
    #for name, param in model.named_parameters():
        #print(name)
    # NOTE(review): the epoch limit is hard-coded to 60 while the cosine
    # schedule above is built from args.num_epochs.
    for epoch in range(start_epoch, 60):
        # The scheduler is stepped at the start of each epoch, before train().
        scheduler.step()#scheduler.step(epoch)
        t1 = time.time()
        train(args, train_queue, model, criterion, optimizer, writer_train, epoch, model_kd)
        test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test, epoch)
        t2 = time.time()
        print(epoch, t2 - t1)
        logging.info('TEST Top1: %e Top5: %e ', test_prec1, test_prec5)

        # is_best is decided before best_prec1 is updated.
        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
        logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)

        # 'epoch' stores the next epoch to run, matching start_epoch on resume.
        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)
        train_queue.reset()
        valid_queue.reset()

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
示例#28
0
def main():
    """Evaluate, and optionally fine-tune, a pruned ResNet restored from ``args.refine``.

    Steps performed:
      1. Build the checkpoint helper, TensorBoard writers and the data loader
         selected by ``args.dataset``.
      2. Load the checkpoint at ``args.refine`` and either rebuild the sparse
         model from its stored mask (``args.pruned``) or prune the dense
         weights via ``prune_resnet``.
      3. Run one test pass; stop here when ``args.test_only`` is set.
      4. Fine-tune with SGD + StepLR for ``args.num_epochs`` epochs, saving a
         checkpoint after every epoch.

    Raises:
        AssertionError: if ``args.refine`` is ``None``.
    """
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    device = torch.device(f"cuda:{args.gpus[0]}")
    checkpoint = torch.load(refine, map_location=device)

    if args.pruned:
        mask = checkpoint['mask']
        # Count the mask entries that are zero. The comparison must be made on
        # each element; comparing the whole container to 0 never counts
        # individual entries.
        pruned = sum(1 for m in mask if m == 0)
        print(f"Pruned / Total: {pruned} / {len(mask)}")
        model = resnet_56_sparse(has_mask = mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test)
    print(f"Simply test after prune {test_prec1:.3f}")

    if args.test_only:
        return

    if args.keep_grad:
        # Freeze every parameter whose name contains 'mask' so fine-tuning
        # does not update it.
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=device)
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        # best_prec5 is part of the state saved below; restore it when present
        # so the final report does not restart from 0 after a resume.
        best_prec5 = checkpoint.get('best_prec5', best_prec5)
        # The state built in the loop below stores the weights under
        # 'state_dict_s'; fall back to 'state_dict' for checkpoints written
        # with the other key.
        weights_key = 'state_dict_s' if 'state_dict_s' in checkpoint else 'state_dict'
        model.load_state_dict(checkpoint[weights_key])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        # NOTE(review): passing the epoch to step() is deprecated in recent
        # PyTorch releases; kept as-is to preserve the learning-rate schedule.
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer, writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, epoch)

        # Decide "best" before folding the new score into the running maximum.
        is_best_finetune = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, False, is_best_finetune)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
示例#29
0
def main():
    """Train a student sequence model by distilling a pre-trained teacher.

    The teacher (``SeqModel``) is restored from ``FPATH_T_NET_CHECKPOINT`` and
    its first child module is used as the encoder target.  The student
    (``SeqModel_Student``) receives copies of the teacher's ``feature`` and
    ``classifier`` modules.  Each training step performs two updates:

      * ``SM_optim`` updates the student encoder (``SM.enc``) from a
        distillation loss (BCE-with-logits + L1) against the teacher encoder
        output;
      * ``SM2_optim`` updates ``SM.feature`` and ``SM.classifier`` from a
        masked BCE classification loss on the query positions.

    Running statistics are appended to the module-level ``hist_trloss`` /
    ``hist_tracc`` lists every 500 sessions; validation and a checkpoint are
    run every 25000 sessions and at the end of every epoch.
    """
    log_interval = 500      # sessions between progress reports
    ckpt_interval = 25000   # sessions between intermediate checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071),  # 80% for training
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(config_fpath=args.config,
                                    mtrain_mode=True,  # True, because we use part of trainset as testset
                                    data_sel=(99965071, 104965071),  # (99965071, 124950714) would be the full 20% for testing
                                    batch_size=TS_BATCH_SZ,
                                    shuffle=False,
                                    seq_mode=True)

    # Load Teacher net
    SMT = SeqModel().cuda(GPU)
    checkpoint = torch.load(FPATH_T_NET_CHECKPOINT, map_location='cuda:{}'.format(GPU))
    tqdm.write("Loading saved teacher model from '{0:}'... loss: {1:.6f}".format(FPATH_T_NET_CHECKPOINT,checkpoint['loss']))
    SMT.load_state_dict(checkpoint['SM_state'])

    # Teacher encoder = first child module of the teacher network.
    SMT_Enc = nn.Sequential(*list(SMT.children())[:1]).cuda(GPU)

    # Init Student net --> copy feature and classifier from the Teacher net
    SM = SeqModel_Student().cuda(GPU)
    SM.feature = deepcopy(SMT.feature)
    SM.classifier = deepcopy(SMT.classifier)
    SM = SM.cuda(GPU)
    Distill_parameters = SM.enc.parameters()
    Classifier_parameters = [{'params': SM.feature.parameters()},
                             {'params': SM.classifier.parameters()}]

    SM_optim = torch.optim.Adam(Distill_parameters, lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.9)
    SM2_optim = torch.optim.Adam(Classifier_parameters, lr=LEARNING_RATE)

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Most recently created checkpoint file under MODEL_SAVE_PATH.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),key=os.path.getctime)
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath,checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query    = 0
        total_trloss   = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SMT.eval()  # Teacher-net
            SM.train()  # Student-net
            # Use the builtin next(): iterator objects are not guaranteed to
            # expose a Python-2 style .next() method.
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:,0].detach().numpy().flatten()  # If num_items was odd number, query has one more item.
            num_query   = num_items[:,1].detach().numpy().flatten()
            batch_sz    = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0,2,1)  # b x 70 x 20

            # x_feat_T: Teacher-net input, x_feat_S: Student-net input (que-log is excluded)
            x_feat_T = torch.zeros(batch_sz, 72, 20)
            x_feat_T[:,:70,:] = x.clone()
            x_feat_T[:, 70,:10] = 1  # Sup/Que state indicator
            x_feat_T[:, 71,:10] = labels[:,:10].clone()

            x_feat_S = x_feat_T.clone()
            x_feat_S[:, :41, 10:] = 0  # remove que-log

            x_feat_T = x_feat_T.cuda(GPU)
            x_feat_S = Variable(x_feat_S).cuda(GPU)

            # Target: Prepare Teacher's intermediate output.
            # The teacher is never optimised here, so skip autograd for its
            # forward pass; this avoids building a graph through the teacher
            # and accumulating unused gradients on its weights.
            with torch.no_grad():
                enc_target = SMT_Enc(x_feat_T)

            # y
            y = labels.clone()

            # y_mask: keep only the query half (positions 10..19)
            y_mask_que = y_mask.clone()
            y_mask_que[:,:10] = 0

            # Forward & update
            y_hat_enc, y_hat = SM(x_feat_S)  # y_hat: b*20

            # Calculate Distillation loss
            loss1 = F.binary_cross_entropy_with_logits(input=y_hat_enc, target=torch.sigmoid(enc_target.cuda(GPU)))
            loss2 = F.l1_loss(input=y_hat_enc, target=enc_target.cuda(GPU))
            loss = loss1+loss2
            total_trloss += loss.item()
            SM.zero_grad()
            # retain_graph: loss_c below is backpropagated through the graph
            # of the same SM forward call.
            loss.backward(retain_graph=True)
            # Update Enc
            # NOTE(review): this step changes the encoder weights in place
            # before loss_c.backward() runs over the retained graph; recent
            # PyTorch versions may reject that -- confirm on the target version.
            SM_optim.step()

            # Calculate Classifier loss
            loss_c = F.binary_cross_entropy_with_logits(input=y_hat*y_mask_que.cuda(GPU), target=y.cuda(GPU)*y_mask_que.cuda(GPU))
            SM.zero_grad()
            loss_c.backward()
            # Update Classifier and feature
            SM2_optim.step()

            # Decision
            SM.eval()
            y_prob = torch.sigmoid(y_hat*y_mask_que.cuda(GPU)).detach().cpu().numpy()  # bx20
            # Builtin int replaces the removed np.int alias (same dtype).
            y_pred = (y_prob[:,10:]>0.5).astype(int)  # bx10
            y_numpy = labels[:,10:].numpy()  # bx10
            # Acc
            total_corrects += np.sum((y_pred==y_numpy)*y_mask_que[:,10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, loss_c, y_hat, y_hat_enc

            if (session+1)%log_interval == 0:
                # Divide by the window length so the logged value is the mean
                # loss per session since the previous report.
                hist_trloss.append(total_trloss/log_interval)
                hist_tracc.append(total_corrects/total_query)
                # Prepare display (first sample of the batch)
                sample_sup = labels[0,(10-num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0,:num_query[0]].astype(int)
                sample_pred = y_pred[0,:num_query[0]]
                sample_prob = y_prob[0,10:10+num_query[0]]

                tqdm.write("S:" + np.array2string(sample_sup) +'\n'+
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query    = 0
                total_trloss   = 0

            if (session+1)%ckpt_interval == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                            'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                            'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
def main():
    """Train ``SeqModel`` on support/query sessions from ``SpotifyDataloader``.

    Each batch is split into a support tensor (first 10 positions plus their
    labels) and a query tensor (all 20 positions, with the log features of the
    query half left at zero).  The model is optimised with a masked
    BCE-with-logits loss over the query positions.

    Running statistics are appended to the module-level ``hist_trloss`` /
    ``hist_tracc`` lists every 500 sessions; validation and a checkpoint are
    run every 25000 sessions and at the end of every epoch, after which the
    learning-rate scheduler is stepped.
    """
    log_interval = 500      # sessions between progress reports
    ckpt_interval = 25000   # sessions between intermediate checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% for training
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 104965071),  # (99965071, 124950714) would be the full 20% for testing
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Most recently created checkpoint file under MODEL_SAVE_PATH.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            # Use the builtin next(): iterator objects are not guaranteed to
            # expose a Python-2 style .next() method.
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            # If num_items was odd number, query has one more item.
            num_support = num_items[:, 0].detach().numpy().flatten()
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0, 2, 1)  # b x 70 x 20
            # Support input: first 10 positions with their labels appended as
            # one extra channel.
            x_sup = Variable(
                torch.cat((x[:, :, :10], labels[:, :10].unsqueeze(1)),
                          1)).cuda(GPU)
            x_que = torch.zeros(batch_sz, 72, 20)
            x_que[:, :41, :10] = x[:, :41, :10].clone()  # fill with x_sup_log
            x_que[:, 41:70, :] = x[:, 41:, :].clone()  # fill with x_sup_feat and x_que_feat
            x_que[:, 70, :10] = 1  # support marking
            x_que[:, 71, :10] = labels[:, :10]  # labels marking
            x_que = Variable(x_que).cuda(GPU)  # b x 72 x 20

            # y
            y = labels.clone()  # bx20

            # y_mask: keep only the query half (positions 10..19)
            y_mask_que = y_mask.clone()
            y_mask_que[:, :10] = 0

            # Forward & update
            y_hat, att = SM(x_sup, x_que)  # y_hat: b*20, att: bx10*20

            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask_que.cuda(GPU),
                target=y.cuda(GPU) * y_mask_que.cuda(GPU))
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask_que.cuda(GPU)).detach().cpu().numpy()  # bx20
            # Builtin int replaces the removed np.int alias (same dtype).
            y_pred = (y_prob[:, 10:] > 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            total_corrects += np.sum(
                (y_pred == y_numpy) * y_mask_que[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, y_hat

            if (session + 1) % log_interval == 0:
                # Divide by the window length so the logged value is the mean
                # loss per session since the previous report.
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display (first sample of the batch)
                sup_start = 10 - num_support[0]
                que_end = 10 + num_query[0]
                sample_att = att[0, sup_start:10,
                                 sup_start:que_end].detach().cpu().numpy()

                sample_sup = labels[0, sup_start:10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:que_end]

                # Print the attention slice one row per line, two decimals.
                tqdm.write(
                    np.array2string(sample_att,
                                    formatter={
                                        'float_kind': lambda v: "%.2f" % v
                                    }).replace('\n ', '').replace(
                                        '][', ']\n[').replace('[[', '['))
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % ckpt_interval == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'SM_state': SM.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'SM_opt_state': SM_optim.state_dict(),
                        'SM_sch_state': SM_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()