Пример #1
0
def main():
    """Distil a pre-trained teacher ``SeqModel`` into a ``SeqModel_Student``.

    The teacher is restored from ``FPATH_T_NET_CHECKPOINT``. The student starts
    with deep copies of the teacher's ``feature`` and ``classifier`` modules
    and receives two separate updates per batch:

    1. ``SM.enc`` is fitted to the output of the teacher's first child module
       (BCE-with-logits against the sigmoid of the teacher output, plus L1).
    2. ``SM.feature`` and ``SM.classifier`` are fitted to the labels with a
       BCE-with-logits loss that is masked to the query half of the sequence.

    Training statistics are appended to ``hist_trloss`` / ``hist_tracc`` (not
    defined in this function) every ``log_interval`` batches. Validation and a
    checkpoint are produced every ``save_interval`` batches and at the end of
    every epoch, after which the learning-rate scheduler is stepped.
    """
    log_interval = 500     # batches between two training-log entries
    save_interval = 25000  # batches between two mid-epoch validations/checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071),  # 80% for training
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(config_fpath=args.config,
                                    mtrain_mode=True,  # True, because we use part of trainset as testset
                                    # Part of the remaining 20%, whose full range is (99965071, 124950714)
                                    data_sel=(99965071, 104965071),
                                    batch_size=TS_BATCH_SZ,
                                    shuffle=False,
                                    seq_mode=True)

    # Load Teacher net
    SMT = SeqModel().cuda(GPU)
    checkpoint = torch.load(FPATH_T_NET_CHECKPOINT, map_location='cuda:{}'.format(GPU))
    tqdm.write("Loading saved teacher model from '{0:}'... loss: {1:.6f}".format(FPATH_T_NET_CHECKPOINT, checkpoint['loss']))
    SMT.load_state_dict(checkpoint['SM_state'])

    # Only the teacher's first child module produces the distillation target.
    SMT_Enc = nn.Sequential(*list(SMT.children())[:1]).cuda(GPU)

    # Init Student net --> copy feature and classifier from the Teacher net
    SM = SeqModel_Student().cuda(GPU)
    SM.feature = deepcopy(SMT.feature)
    SM.classifier = deepcopy(SMT.classifier)
    SM = SM.cuda(GPU)
    Distill_parameters = SM.enc.parameters()
    Classifier_parameters = [{'params': SM.feature.parameters()},
                             {'params': SM.classifier.parameters()}]

    SM_optim = torch.optim.Adam(Distill_parameters, lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.9)
    SM2_optim = torch.optim.Adam(Classifier_parameters, lr=LEARNING_RATE)

    def _save_checkpoint(epoch, session):
        """Store model, optimizer, scheduler and history under MODEL_SAVE_PATH."""
        # NOTE(review): the state of SM2_optim is not part of the checkpoint,
        # so a resumed run restarts the classifier optimizer from scratch.
        torch.save({'ep': epoch, 'sess': session, 'SM_state': SM.state_dict(), 'loss': hist_trloss[-1],
                    'hist_vacc': hist_vacc, 'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss,
                    'SM_opt_state': SM_optim.state_dict(), 'SM_sch_state': SM_scheduler.state_dict()},
                   MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Resume from the most recently created checkpoint file.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"), key=os.path.getctime)
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SMT.eval()  # Teacher-net
            SM.train()  # Student-net
            # Built-in next(): the iterator's .next() method is a Python 2 leftover.
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten()  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0, 2, 1)  # bx70*20

            # x_feat_T: Teacher-net input, x_feat_S: Student-net input (que-log is excluded)
            x_feat_T = torch.zeros(batch_sz, 72, 20)
            x_feat_T[:, :70, :] = x.clone()
            x_feat_T[:, 70, :10] = 1  # Sup/Que state indicator
            x_feat_T[:, 71, :10] = labels[:, :10].clone()

            x_feat_S = x_feat_T.clone()
            x_feat_S[:, :41, 10:] = 0  # remove que-log

            x_feat_T = x_feat_T.cuda(GPU)
            x_feat_S = Variable(x_feat_S).cuda(GPU)

            # Target: teacher's intermediate output. The teacher is never
            # optimized here, so no autograd graph is needed for it.
            with torch.no_grad():
                enc_target = SMT_Enc(x_feat_T)

            # y
            y = labels.clone()

            # y_mask restricted to the query half (positions 10..19)
            y_mask_que = y_mask.clone()
            y_mask_que[:, :10] = 0
            y_mask_que_gpu = y_mask_que.cuda(GPU)

            # Forward & update
            y_hat_enc, y_hat = SM(x_feat_S)  # y_hat: b*20

            # Calculate Distillation loss
            loss1 = F.binary_cross_entropy_with_logits(input=y_hat_enc, target=torch.sigmoid(enc_target.cuda(GPU)))
            loss2 = F.l1_loss(input=y_hat_enc, target=enc_target.cuda(GPU))
            loss = loss1 + loss2
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward(retain_graph=True)
            # Update Enc
            SM_optim.step()

            # Calculate Classifier loss
            # NOTE(review): this backward pass runs after SM_optim.step() has
            # modified the encoder weights in place; confirm this ordering is
            # accepted by the PyTorch version in use.
            loss_c = F.binary_cross_entropy_with_logits(input=y_hat * y_mask_que_gpu, target=y.cuda(GPU) * y_mask_que_gpu)
            SM.zero_grad()
            loss_c.backward()
            # Update Classifier and feature
            SM2_optim.step()

            # Decision
            SM.eval()
            y_prob = torch.sigmoid(y_hat * y_mask_que_gpu).detach().cpu().numpy()  # bx20
            y_pred = (y_prob[:, 10:] > 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            total_corrects += np.sum((y_pred == y_numpy) * y_mask_que[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, loss_c, y_hat, y_hat_enc

            if (session + 1) % log_interval == 0:
                # Average over exactly the batches accumulated since the last reset.
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_sup = labels[0, (10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]

                tqdm.write("S:" + np.array2string(sample_sup) + '\n' +
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % save_interval == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                _save_checkpoint(epoch, session)
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        _save_checkpoint(epoch, session)
        SM_scheduler.step()
Пример #2
0
def main():
    """Train the feature encoder and relation network on Spotify sessions.

    Every batch is split into support and query items. ``FeatEnc`` embeds both
    sets, ``RN`` scores each query item, and both networks are updated with a
    BCE-with-logits loss that is masked to the valid query positions.

    Training statistics are appended to ``hist_trloss`` / ``hist_tracc`` (not
    defined in this function) every ``log_interval`` batches. Validation and a
    checkpoint are produced every ``save_interval`` batches and at the end of
    every epoch.
    """
    log_interval = 900    # batches between two training-log entries
    save_interval = 4000  # batches between two mid-epoch validations/checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% for training
        batch_size=TR_BATCH_SZ,
        shuffle=True)  # shuffle has to be True (original note: "later...")

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 124950714),  # 20% for testing
        batch_size=2048,
        shuffle=False)

    # Init neural net
    FeatEnc = MLP(input_sz=29, hidden_sz=256, output_sz=64).cuda(GPU)
    RN = RelationNetwork().cuda(GPU)

    FeatEnc_optim = torch.optim.Adam(FeatEnc.parameters(), lr=LEARNING_RATE)
    RN_optim = torch.optim.Adam(RN.parameters(), lr=LEARNING_RATE)

    # NOTE(review): both schedulers are created, checkpointed and restored,
    # but nothing in this function calls .step() on them, so the learning
    # rate never decays. Confirm whether a per-batch step was intended.
    FeatEnc_scheduler = StepLR(FeatEnc_optim, step_size=100000, gamma=0.2)
    RN_scheduler = StepLR(RN_optim, step_size=100000, gamma=0.2)

    def _save_checkpoint(epoch, session):
        """Store models, optimizers, schedulers and history under MODEL_SAVE_PATH."""
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'FE_state': FeatEnc.state_dict(),
                'RN_state': RN.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'FE_opt_state': FeatEnc_optim.state_dict(),
                'RN_opt_state': RN_optim.state_dict(),
                'FE_sch_state': FeatEnc_scheduler.state_dict(),
                'RN_sch_state': RN_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Resume from the most recently created checkpoint file.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['hist_trloss'][-1]))
        FeatEnc.load_state_dict(checkpoint['FE_state'])
        RN.load_state_dict(checkpoint['RN_state'])
        FeatEnc_optim.load_state_dict(checkpoint['FE_opt_state'])
        RN_optim.load_state_dict(checkpoint['RN_opt_state'])
        FeatEnc_scheduler.load_state_dict(checkpoint['FE_sch_state'])
        RN_scheduler.load_state_dict(checkpoint['RN_sch_state'])
        START_EPOCH = checkpoint['ep']

    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):

        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):

            FeatEnc.train()
            RN.train()
            # Built-in next(): the iterator's .next() method is a Python 2 leftover.
            (x_sup, x_que, x_log_sup, x_log_que, label_sup, label_que,
             num_items, index) = next(tr_sessions_iter)
            x_sup, x_que = Variable(x_sup).cuda(GPU), Variable(x_que).cuda(GPU)
            x_log_sup = Variable(x_log_sup).cuda(GPU)
            x_log_que = Variable(x_log_que).cuda(GPU)
            label_sup = Variable(label_sup).cuda(GPU)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            # If num_items was odd number, query has one more item.
            num_support = num_items[:, 0].detach().numpy().flatten()
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            x_sup = x_sup.unsqueeze(2)  # 1x7*29 --> 1x7x1*29
            x_que = x_que.unsqueeze(2)  # 1x8*29 --> 1x8x1*29

            # - feature encoder
            x_feat_sup = FeatEnc(x_sup)  # 1x7x1*64
            x_feat_que = FeatEnc(x_que)  # 1x8x1*64

            # - relation network
            y_hat = RN(x_feat_sup, x_feat_que, x_log_sup, x_log_que,
                       label_sup)  # bx8

            # Prepare ground-truth similarity score and mask: only the first
            # num_query[b] of the 10 query slots of sample b are valid.
            y_gt = label_que[:, :, 1]
            y_mask = np.zeros((batch_sz, 10), dtype=np.float32)
            for b in np.arange(batch_sz):
                y_mask[b, :num_query[b]] = 1
            y_mask = torch.FloatTensor(y_mask).cuda(GPU)

            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(input=y_hat * y_mask,
                                                      target=y_gt.cuda(GPU) *
                                                      y_mask)
            total_trloss += loss.item()

            # Update Nets
            FeatEnc.zero_grad()
            RN.zero_grad()

            loss.backward()

            FeatEnc_optim.step()
            RN_optim.step()

            # Decision
            y_sigmoid = torch.sigmoid(y_hat)
            y_prob = (y_sigmoid * y_mask).detach().cpu().numpy()
            y_pred = ((y_sigmoid > 0.5).float() *
                      y_mask).detach().cpu().long().numpy()

            # Acc
            total_corrects += np.sum(
                (y_pred == y_gt.long().numpy()) * y_mask.cpu().numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, x_feat_sup, x_feat_que, y_hat, y_sigmoid

            if (session + 1) % log_interval == 0:
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)

                # Display samples are built only when they are printed, which
                # avoids a GPU->CPU copy of label_sup on every batch.
                sample_sup = label_sup[0, :num_support[0],
                                       1].detach().long().cpu().numpy().flatten()
                sample_que = label_que[0, :num_query[0],
                                       1].long().numpy().flatten()
                sample_pred = y_pred[0, :num_query[0]].flatten()
                sample_prob = y_prob[0, :num_query[0]].flatten()

                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))

                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % save_interval == 0:
                # Validation
                validate(mval_loader, FeatEnc, RN, eval_mode=True)
                # Save
                _save_checkpoint(epoch, session)

        # Validation
        validate(mval_loader, FeatEnc, RN, eval_mode=True)
        # Save
        _save_checkpoint(epoch, session)
def main():
    """Train ``SeqModel`` to predict both the session log and the skip label.

    For every batch the log features and the labels are shifted right by one
    step and blanked for the query half (only position 10, the first query
    step, still receives the last support item's log and label) before being
    fed to the model. The model returns two outputs, ``y_hat_qlog`` and
    ``y_hat_skip``; each gets a masked BCE-with-logits loss and the sum of
    both is optimized.

    Training statistics are appended to ``hist_trloss``, ``hist_trloss_qlog``,
    ``hist_trloss_skip`` and ``hist_tracc`` (not defined in this function)
    every ``log_interval`` batches. Validation and a checkpoint are produced
    every ``save_interval`` batches and at the end of every epoch, after which
    the learning-rate scheduler is stepped.
    """
    log_interval = 500    # batches between two training-log entries
    save_interval = 8000  # batches between two mid-epoch validations/checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071),  # 80% for training
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(config_fpath=args.config,
                                    mtrain_mode=True,  # True, because we use part of trainset as testset
                                    # Part of the remaining 20%, whose full range is (99965071, 124950714)
                                    data_sel=(99965071, 101065071),
                                    batch_size=TS_BATCH_SZ,
                                    shuffle=False,
                                    seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    def _save_checkpoint(epoch, session):
        """Store model, optimizer, scheduler and history under MODEL_SAVE_PATH."""
        torch.save({'ep': epoch, 'sess': session, 'SM_state': SM.state_dict(), 'loss': hist_trloss[-1],
                    'hist_trloss_qlog': hist_trloss_qlog, 'hist_trloss_skip': hist_trloss_skip, 'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()},
                   MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Resume from the most recently created checkpoint file.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"), key=os.path.getctime)
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss_qlog = 0
        total_trloss_skip = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SM.train()
            # Built-in next(): the iterator's .next() method is a Python 2 leftover.
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten()  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: bx70*20
            x = x.permute(0, 2, 1)

            # Prepare ground truth log and label, y
            y_qlog = x[:, :41, :].clone()  # bx41*20
            y_skip = labels.clone()  # bx20
            y_mask_qlog = y_mask.unsqueeze(1).repeat(1, 41, 1)  # bx41*20
            y_mask_skip = y_mask  # bx20
            mask_qlog_gpu = y_mask_qlog.cuda(GPU)
            mask_skip_gpu = y_mask_skip.cuda(GPU)

            # log shift: bx41*20 (position t holds the log of item t-1)
            log_shift = torch.zeros(batch_sz, 41, 20)
            log_shift[:, :, 1:] = x[:, :41, :-1]
            log_shift[:, :, 11:] = 0  # delete the logs of query items

            # labels_shift: bx1*20 (model can only observe past labels)
            labels_shift = torch.zeros(batch_sz, 1, 20)
            labels_shift[:, 0, 1:] = labels[:, :-1].float()
            labels_shift[:, 0, 11:] = 0  # no label for previous query items

            # support/query state labels: bx1*20
            sq_state = torch.zeros(batch_sz, 1, 20)
            sq_state[:, 0, :11] = 1

            # Pack x: bx72*20 = 41 shifted log + 29 remaining features + 1 label + 1 state
            x = Variable(torch.cat((log_shift, x[:, 41:, :], labels_shift, sq_state), 1)).cuda(GPU)

            # Forward & update
            y_hat_qlog, y_hat_skip = SM(x)  # y_hat: b*20

            # Calculate BCE loss
            loss_qlog = F.binary_cross_entropy_with_logits(input=y_hat_qlog.cuda(GPU) * mask_qlog_gpu,
                                                           target=y_qlog.cuda(GPU) * mask_qlog_gpu)
            loss_skip = F.binary_cross_entropy_with_logits(input=y_hat_skip.cuda(GPU) * mask_skip_gpu,
                                                           target=y_skip.cuda(GPU) * mask_skip_gpu)
            loss = loss_qlog + loss_skip
            total_trloss_qlog += loss_qlog.item()
            total_trloss_skip += loss_skip.item()
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(y_hat_skip.detach() * mask_skip_gpu).cpu().numpy()  # bx20
            y_pred = (y_prob[:, 10:] >= 0.5).astype(int)  # bx10
            y_numpy = y_skip[:, 10:].numpy()  # bx10

            # Label Acc
            total_corrects += np.sum((y_pred == y_numpy) * y_mask_skip[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, loss_qlog, loss_skip, y_hat_qlog, y_hat_skip

            if (session + 1) % log_interval == 0:
                hist_trloss_qlog.append(total_trloss_qlog / log_interval)
                hist_trloss_skip.append(total_trloss_skip / log_interval)
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_sup = labels[0, (10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' +
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss(qlog|skip):{1:.6f}({2:.6f}|{3:.6f})  tr_acc:{4:.4f}".format(session,
                           hist_trloss[-1], hist_trloss_qlog[-1], hist_trloss_skip[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0
                total_trloss_qlog = 0
                total_trloss_skip = 0

            if (session + 1) % save_interval == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                _save_checkpoint(epoch, session)
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        _save_checkpoint(epoch, session)
        SM_scheduler.step()
Пример #4
0
# Keep only the program name in sys.argv; every other command-line argument
# is dropped from this point on.
sys.argv = [sys.argv[0]]

# Settings taken from the already-parsed `args`:
#   GPU             - device index
#   MODEL_PATH      - module (path of train code including validate() and model)
#   CHECKPOINT_PATH - checkpoint directory
GPU, MODEL_PATH, CHECKPOINT_PATH = args.gpu, args.model_py, args.save_path
# Submission output goes to the directory part of the checkpoint path.
SUBMISSION_OUTPUT_PATH = os.path.split(CHECKPOINT_PATH)[0]

# TSSET stats: 518275860 items within 31251398 sessions
print("Initializing dataloader...")
# mtrain_mode=False selects the test set used for the submission. No data_sel
# is passed here (a small range such as (0, 100) was used for quick trials).
mtest_loader = SpotifyDataloader(config_fpath=args.config,
                                 mtrain_mode=False,
                                 batch_size=4096,
                                 shuffle=False)

#mtest_loader  = SpotifyDataloader(config_fpath=args.config,
#                                  mtrain_mode=True, # True, because we use part of trainset as testset
#                                  data_sel=(99965071, 110075071),#(99965071, 124950714), # 20%를 테스트
#                                  batch_size=10,
#                                  shuffle=True)


def save_submission(output, output_path):
    """Write ``output`` to ``output_path``, one line per entry.

    The items of every entry are converted with ``str`` and concatenated
    without a separator, e.g. ``[0, 1, 1]`` becomes the line ``011``.

    Args:
        output: Iterable of iterables; each inner iterable becomes one line.
        output_path: Destination file. Missing parent directories are
            created and an existing file is overwritten.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        # A bare file name has no directory part and os.makedirs('') raises.
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w") as f:
        # The assembled lines have to be written; otherwise the file stays empty.
        f.writelines(''.join(map(str, row)) + '\n' for row in output)
Пример #5
0
def main():
    """Train ``SeqModel`` on session features extended by a frozen ``CF_model``.

    ``CF_model`` (an ``MLP_Regressor`` restored from ``CF_CHECKPOINT_PATH``)
    maps the features after column 41 of every item to an additional feature
    block. That block, the original features with the first 41 columns of the
    query items zeroed, the shifted labels and a support/query indicator are
    concatenated and fed to ``SeqModel``, which is optimized with a masked
    BCE-with-logits loss.

    Training statistics are appended to ``hist_trloss`` / ``hist_tracc`` (not
    defined in this function) every ``log_interval`` batches. Validation and a
    checkpoint are produced every ``save_interval`` batches and at the end of
    every epoch, after which the learning-rate scheduler is stepped.
    """
    log_interval = 500     # batches between two training-log entries
    save_interval = 20000  # batches between two mid-epoch validations/checkpoints

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% for training
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        # Part of the remaining 20%, whose full range is (99965071, 124950714)
        data_sel=(99965071, 104965071),
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.7)

    # CF_model only provides input features; no optimizer is created for it.
    # NOTE(review): CF_model is left in its default (training) mode here;
    # confirm whether CF_model.eval() is needed for its layers.
    CF_model = MLP_Regressor().cuda(GPU)
    CF_checkpoint = torch.load(CF_CHECKPOINT_PATH,
                               map_location='cuda:{}'.format(GPU))
    CF_model.load_state_dict(CF_checkpoint['model_state'])

    def _save_checkpoint(epoch, session):
        """Store model, optimizer, scheduler and history under MODEL_SAVE_PATH."""
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        # Resume from the most recently created checkpoint file.
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            # Built-in next(): the iterator's .next() method is a Python 2 leftover.
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            # If num_items was odd number, query has one more item.
            num_support = num_items[:, 0].detach().numpy().flatten()
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x[:, 10:, :41] = 0  # delete the meta-log of query items

            # labels_shift: (model can only observe past labels)
            labels_shift = torch.zeros(batch_sz, 20, 1)
            labels_shift[:, 1:, 0] = labels[:, :-1].float()
            # no label for previous query items
            labels_shift[:, 11:, 0] = 0
            # support/query state labels
            sq_state = torch.zeros(batch_sz, 20, 1)
            sq_state[:, :11, 0] = 1

            # compute lastfm_output; CF_model is frozen, so no autograd graph
            # is recorded for it.
            x_audio = x[:, :, 41:].data.clone()
            x_audio = Variable(x_audio, requires_grad=False).cuda(GPU)
            with torch.no_grad():
                x_emb_lastfm, x_lastfm = CF_model(x_audio)
            x_lastfm = x_lastfm.cpu()
            del x_emb_lastfm

            # Pack x: bx122*20
            x = Variable(
                torch.cat((x_lastfm, x, labels_shift, sq_state),
                          dim=2).permute(0, 2, 1)).cuda(GPU)

            y_mask_gpu = y_mask.cuda(GPU)

            # Forward & update
            y_hat = SM(x)  # y_hat: b*20
            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask_gpu,
                target=labels.cuda(GPU) * y_mask_gpu)
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask_gpu).detach().cpu().numpy()  # bx20
            y_pred = (y_prob[:, 10:] >= 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            y_query_mask = y_mask[:, 10:].numpy()
            total_corrects += np.sum((y_pred == y_numpy) * y_query_mask)
            total_query += np.sum(num_query)
            # Restore GPU memory
            del loss, y_hat

            if (session + 1) % log_interval == 0:
                # Average over exactly the batches accumulated since the last reset.
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display. Support items are left-padded, so the real
                # support labels are the last num_support[0] of the first 10.
                sample_sup = labels[
                    0, (10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % save_interval == 0:
                # Validation
                validate(mval_loader, SM, CF_model, eval_mode=True)
                # Save
                _save_checkpoint(epoch, session)
        # Validation
        validate(mval_loader, SM, CF_model, eval_mode=True)
        # Save
        _save_checkpoint(epoch, session)
        SM_scheduler.step()
Пример #6
0
def main():
    """Train ``SeqModel`` on Spotify listening sessions.

    Builds the training and validation loaders, optionally resumes from the
    most recently created ``check*.pth`` file under ``MODEL_SAVE_PATH``, then
    trains for the remaining epochs.  Every 500 loader iterations (called
    "sessions" in the code) a running loss / accuracy entry is appended to
    ``hist_trloss`` / ``hist_tracc`` and one sample is printed; every 25000
    iterations and at the end of each epoch the model is validated and a
    checkpoint is written.  The learning rate is multiplied by 0.8 after each
    epoch.

    Relies on module-level names defined elsewhere in the file: ``args``, the
    hyperparameter constants (``GPU``, ``EPOCHS``, ``LEARNING_RATE``,
    ``TR_BATCH_SZ``, ``TS_BATCH_SZ``, ``MODEL_SAVE_PATH``), the ``hist_*``
    lists and ``validate``.
    """
    log_interval = 500  # iterations between training-log entries
    ckpt_interval = 25000  # iterations between mid-epoch validation + save

    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # first 80% of sessions: training split
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        # 5,000,000 held-out sessions; the full remaining 20% would be
        # (99965071, 124950714).
        data_sel=(99965071, 104965071),
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    def _save_checkpoint(epoch, session):
        """Write model, optimizer, scheduler state and histories to disk."""
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                # No loss entry exists before the first log interval has
                # elapsed; store NaN instead of raising IndexError.
                'loss': hist_trloss[-1] if hist_trloss else float('nan'),
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        # NOTE(review): training restarts at the beginning of the saved
        # epoch, even if the checkpoint was written at the end of it.
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            # Re-enter training mode every step; validate() is called with
            # eval_mode=True and presumably leaves the model in eval mode.
            SM.train()
            # Use the builtin next(): the ``.next()`` method is not part of
            # the Python 3 iterator protocol.
            # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS
            x, labels, y_mask, num_items, _ = next(tr_sessions_iter)

            # Number of query items per sample, e.g. 15 items = 7 support +
            # 8 query (an odd count gives the extra item to the query side).
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded
            # with zeros. The last 10 are queries right-padded.
            x = x.permute(0, 2, 1)  # bx70*20
            # Support input: all features of the 10 support slots plus their
            # labels as an extra channel -> bx71(41+29+1)*10
            x_sup = torch.cat((x[:, :, :10], labels[:, :10].unsqueeze(1)),
                              1).cuda(GPU)
            # Query input: bx72*20
            x_que = torch.zeros(batch_sz, 72, 20)
            x_que[:, :41, :10] = x[:, :41, :10]  # fill with x_sup_log
            x_que[:, 41:70, :] = x[:, 41:, :]  # x_sup_feat and x_que_feat
            x_que[:, 70, :10] = 1  # support marking
            x_que[:, 71, :10] = labels[:, :10]  # labels marking
            x_que = x_que.cuda(GPU)

            # Mask restricted to the query half (support positions zeroed);
            # moved to the GPU once and reused below.
            y_mask_que = y_mask.clone()
            y_mask_que[:, :10] = 0
            mask_gpu = y_mask_que.cuda(GPU)

            # Forward & update
            y_hat, att = SM(x_sup, x_que)  # y_hat: b*20, att: bx10*20
            masked_logits = y_hat * mask_gpu

            # Calculate BCE loss on masked logits / masked targets
            loss = F.binary_cross_entropy_with_logits(
                input=masked_logits, target=labels.cuda(GPU) * mask_gpu)
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                masked_logits).detach().cpu().numpy()  # bx20
            # Builtin int: the ``np.int`` alias was removed from NumPy.
            y_pred = (y_prob[:, 10:] > 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc: count correct predictions on unmasked query positions
            total_corrects += np.sum(
                (y_pred == y_numpy) * y_mask_que[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, y_hat, masked_logits

            if (session + 1) % log_interval == 0:
                # Average over the iterations actually accumulated since the
                # last reset (the interval itself, not a fixed 900).
                hist_trloss.append(total_trloss / log_interval)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display for the first sample of the batch
                num_support = num_items[:, 0].detach().numpy().flatten()
                sup_start = 10 - num_support[0]
                que_end = 10 + num_query[0]
                sample_att = att[0, sup_start:10,
                                 sup_start:que_end].detach().cpu().numpy()

                sample_sup = labels[0, sup_start:10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:que_end]

                tqdm.write(
                    np.array2string(sample_att,
                                    formatter={
                                        'float_kind': lambda v: "%.2f" % v
                                    }).replace('\n ', '').replace(
                                        '][', ']\n[').replace('[[', '['))
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % ckpt_interval == 0:
                # Mid-epoch validation and checkpoint
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                _save_checkpoint(epoch, session)
        # End-of-epoch validation and checkpoint
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        _save_checkpoint(epoch, session)
        SM_scheduler.step()
Пример #7
0
# Hyperparameters and runtime options read from the parsed CLI arguments.
CLASS_NUM = args.class_num
EPOCHS = args.epochs
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
TR_BATCH_SZ = args.train_batch_size
GPU = args.gpu

# Model-save directory: create the parent directory of the checkpoint prefix
# if it does not exist yet.
MODEL_SAVE_PATH = args.save_path
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

# Trainset stats: 2072002577 items from 124950714 sessions
_TRAIN_SESSIONS = (0, 99965071)  # first 80% of sessions: training
_VAL_SESSIONS = (99965071, 124950714)  # remaining 20%: held out for testing

print('Initializing dataloader...')
# Original author's note: shuffle has to be set to True later on.
mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                  mtrain_mode=True,
                                  data_sel=_TRAIN_SESSIONS,
                                  batch_size=TR_BATCH_SZ,
                                  shuffle=True)

# mtrain_mode stays True because the held-out part of the trainset is used
# as the test set.
mval_loader = SpotifyDataloader(config_fpath=args.config,
                                mtrain_mode=True,
                                data_sel=_VAL_SESSIONS,
                                batch_size=2048,
                                shuffle=True)


#Feature encoder:
class MLP(nn.Module):
    """Feature encoder module.

    NOTE(review): only the constructor header is present in this excerpt;
    the size arguments are accepted but not used by the visible code.
    """

    def __init__(self, input_sz, hidden_sz, output_sz):
        # Zero-argument super() is equivalent to super(MLP, self) here.
        super().__init__()
Пример #8
0
parser.add_argument("-g", "--gpu", type=int, default=0)
#parser.add_argument("-e","--embed_hidden_unit",type=int, default=2)
args = parser.parse_args()

# Hyper Parameters
FEATURE_DIM = args.feature_dim
RELATION_DIM = args.relation_dim
CLASS_NUM = args.class_num
EPOCHS = args.epochs
TEST_EPISODE = args.test_episode
LEARNING_RATE = args.learning_rate
GPU = args.gpu

# Loader over the training range, one session per batch and in fixed order,
# so the data can be inspected deterministically.
mtrain_loader = SpotifyDataloader(
    config_fpath=args.config,
    mtrain_mode=True,
    data_sel=(0, 99965071),  # training range; held-out: (99965071, 124950714)
    batch_size=1,
    shuffle=False)

#%%
# Walk through every batch of the loader; the features and labels of the
# first batch are kept in a0f / a0l and the labels are printed.
train_sessions_iter = iter(mtrain_loader)
for train_session in trange(len(train_sessions_iter),
                            desc='varify',
                            position=0):
    # Use the builtin next(): the ``.next()`` method is not part of the
    # Python 3 iterator protocol.
    feats, labels, num_items, index = next(train_sessions_iter)
    if train_session == 0:
        a0f = feats.detach().numpy()
        a0l = labels.detach().numpy()
        #tqdm.write(np.array2string(a0f))
        tqdm.write(np.array2string(a0l))