예제 #1
0
def BSN_Train_PEM(opt):
    """Train the PEM model, evaluating on the validation subset every epoch.

    Builds the model, an Adam optimizer, the train/validation data loaders
    and a ``StepLR`` schedule from ``opt``, then runs ``train_PEM`` followed
    by ``test_PEM`` for ``opt["pem_epoch"]`` epochs. A ``SummaryWriter`` is
    handed to both so they can log.

    Args:
        opt: Mapping of options. Keys read directly here are
            ``pem_training_lr``, ``pem_weight_decay``, ``pem_step_size``,
            ``pem_step_gamma`` and ``pem_epoch``. The mapping is also
            forwarded unchanged to ``PEM``, ``ProposalDataSet``,
            ``train_PEM`` and ``test_PEM``.
    """
    model = torch.nn.DataParallel(PEM(opt), device_ids=GPU_IDs).cuda()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt["pem_training_lr"],
                           weight_decay=opt["pem_weight_decay"])

    def collate_fn(batch):
        # Each sample is (data, (iou, is_whole_length)); concatenate every
        # field across the samples of the batch.
        batch_data = torch.cat([x[0] for x in batch])
        batch_iou = torch.cat([x[1][0] for x in batch])
        batch_is_whole_length = torch.cat([x[1][1] for x in batch])
        return batch_data, batch_iou, batch_is_whole_length

    # Both loaders share every setting except the dataset subset.
    loader_kwargs = dict(batch_size=model.module.batch_size,
                         shuffle=True,
                         num_workers=8,
                         pin_memory=True,
                         drop_last=True,
                         collate_fn=collate_fn)

    train_loader = torch.utils.data.DataLoader(
        ProposalDataSet(opt, subset="train"), **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        ProposalDataSet(opt, subset="validation"), **loader_kwargs)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=opt["pem_step_size"],
                                                gamma=opt["pem_step_gamma"])

    # Created right before the loop and closed in ``finally`` so the writer
    # is released even when an epoch raises.
    writer = SummaryWriter()
    try:
        for epoch in range(opt["pem_epoch"]):
            train_PEM(train_loader, model, optimizer, epoch, writer, opt)
            test_PEM(test_loader, model, epoch, writer, opt)
            # Step the schedule AFTER the epoch's optimizer updates (assumed
            # to happen inside train_PEM, which receives the optimizer).
            # Since PyTorch 1.1, stepping the scheduler first skips the
            # first value of the learning-rate schedule.
            scheduler.step()
    finally:
        writer.close()
def BSN_Train_PEM(opt):
    """Train the PEM model with optional Comet logging.

    Builds the model, an Adam optimizer, a sampler-driven training loader,
    an evaluation loader and a ``MultiStepLR`` schedule from ``opt``. Runs
    one evaluation pass before training (epoch ``-1``, global step ``-1``),
    then alternates ``train_PEM`` / ``test_PEM`` for ``opt["pem_epoch"]``
    epochs, stepping the schedule at the end of each epoch.

    Args:
        opt: Mapping of options. Keys read directly here are
            ``pem_training_lr``, ``pem_weight_decay``,
            ``pem_max_zero_weight``, ``data_workers``, ``pem_do_index``,
            ``dataset``, ``pem_lr_milestones`` (comma-separated integers),
            ``pem_step_gamma``, ``log_to_comet``, ``local_comet_dir``,
            ``name`` (only when an experiment is created) and
            ``pem_epoch``. The mapping is also forwarded to ``PEM``,
            ``ProposalDataSet``, ``train_PEM`` and ``test_PEM``.
    """
    import os  # local import: only needed for the Comet API-key lookup

    model = torch.nn.DataParallel(PEM(opt)).cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt["pem_training_lr"],
                           weight_decay=opt["pem_weight_decay"])

    param_count = sum(p.numel() for p in model.parameters())
    print('Total params: %.2fM' % (param_count / 1000000.0))

    def collate_fn(batch):
        # Each sample is (data, iou); concatenate both across the batch.
        batch_data = torch.cat([x[0] for x in batch])
        batch_iou = torch.cat([x[1] for x in batch])
        return batch_data, batch_iou

    # With 'pem_do_index' set, fall back to the DataLoader's default
    # collation instead of the concatenating one above.
    batch_collate = None if opt['pem_do_index'] else collate_fn

    train_dataset = ProposalDataSet(opt, subset="train")
    train_sampler = ProposalSampler(train_dataset.proposals,
                                    train_dataset.indices,
                                    max_zero_weight=opt['pem_max_zero_weight'])
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=model.module.batch_size,
        shuffle=False,  # ordering is delegated to the sampler
        sampler=train_sampler,
        num_workers=opt['data_workers'],
        pin_memory=True,
        drop_last=False,
        collate_fn=batch_collate)

    # ActivityNet is evaluated on "validation"; any other dataset on "test".
    subset = "validation" if opt['dataset'] == 'activitynet' else "test"
    test_loader = torch.utils.data.DataLoader(
        ProposalDataSet(opt, subset=subset),
        batch_size=model.module.batch_size,
        shuffle=True,
        num_workers=opt['data_workers'],
        pin_memory=True,
        drop_last=False,
        collate_fn=batch_collate)

    milestones = [int(k) for k in opt['pem_lr_milestones'].split(',')]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=opt['pem_step_gamma'])

    # SECURITY NOTE(review): a Comet API key is committed in this source
    # file. The COMET_API_KEY environment variable now takes precedence; the
    # literal remains only as a fallback so existing runs keep working. It
    # should be rotated and removed from the repository.
    api_key = (os.environ.get("COMET_API_KEY")
               or "hIXq6lDzWzz24zgKv7RYz6blo")
    comet_kwargs = dict(api_key=api_key,
                        project_name="bsnpem",
                        workspace="cinjon",
                        auto_metric_logging=True,
                        auto_output_logging=None,
                        auto_param_logging=False)

    if opt['log_to_comet']:
        comet_exp = CometExperiment(**comet_kwargs)
    elif opt['local_comet_dir']:
        comet_exp = OfflineExperiment(
            offline_directory=opt['local_comet_dir'], **comet_kwargs)
    else:
        comet_exp = None

    if comet_exp:
        comet_exp.log_parameters(opt)
        comet_exp.set_name(opt['name'])

    # Baseline evaluation of the untrained model before the first epoch.
    test_PEM(test_loader, model, -1, -1, comet_exp, opt)

    global_step = 0
    for epoch in range(opt["pem_epoch"]):
        # train_PEM returns the updated global step, which is carried into
        # the next epoch and passed to the evaluation pass.
        global_step = train_PEM(train_loader, model, optimizer, epoch,
                                global_step, comet_exp, opt)
        test_PEM(test_loader, model, epoch, global_step, comet_exp, opt)
        scheduler.step()