def test_optimizer_scheduler_loader(self):
        from torch.optim.lr_scheduler import StepLR

        _scheduler = self.test_optimizer_scheduler_saver()

        model = Model()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
        scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

        self.optimizer_manager.load_with_scheduler(optimizer, scheduler, step=938)
        self.assertEqual(_scheduler.state_dict(), scheduler.state_dict())
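All the examples on this page rely on the same mechanism: the optimizer and the StepLR scheduler each expose state_dict()/load_state_dict(), so both can be checkpointed and restored together with the model. A minimal sketch of that pattern, assuming plain torch.save/torch.load in place of the project-specific optimizer_manager used above:

# Minimal sketch: checkpointing an optimizer together with a StepLR scheduler.
# nn.Linear stands in for a real model; 'checkpoint.pt' is an arbitrary path.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# Save both state_dicts; the scheduler's records its internal step counter.
torch.save({'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()}, 'checkpoint.pt')

# Restore into freshly constructed optimizer/scheduler objects.
checkpoint = torch.load('checkpoint.pt')
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])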
Example #2
def train(model, optimizer, loss_fn, epochs, train_loader, device, model_chckpt_path, checkpoint_save_interval,
          model_path, load_chckpt, log_interval):
    epoch_start = 0

    scheduler = StepLR(optimizer, int(epochs * 0.5), 0.1)

    if load_chckpt and os.path.isfile(model_chckpt_path):
        checkpoint = torch.load(model_chckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch_start = checkpoint['epoch']
        print("Training checkpoints found. Starting training from epoch %d." % epoch_start)

    model.train()
    for epoch in range(epoch_start, epochs):
        running_loss = 0.0
        processed_items = 0
        correct_predictions = 0
        for batch_num, (images, targets) in enumerate(train_loader):
            images, targets = images.to(device), targets.to(device)
            out = model(images)
            optimizer.zero_grad()
            loss = loss_fn(out, targets)
            loss.backward()
            optimizer.step()

            _, correct = calculate_correct_predictions(targets, out)
            running_loss += loss.item()
            processed_items += out.size()[0]
            correct_predictions += correct

            if (batch_num + 1) % log_interval == 0:
                print('[Epoch %d, Batch %4d] Loss: %.10f, Accuracy: %.5f' %
                      (epoch + 1, batch_num + 1, running_loss / processed_items, correct_predictions / processed_items))

        if epoch % checkpoint_save_interval == 0:
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, model_chckpt_path)
    torch.save(model.state_dict(), model_path)
Example #3
def main_mlp():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--num_mlp_layers',
                        type=int,
                        default=6,
                        help='number of mlp layers (default: 6)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0.2,
                        help='dropout ratio (default: 0.2)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=1600,
                        help='embedding dimensionality (default: 1600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--radius',
                        type=int,
                        default=2,
                        help='radius (default: 2)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    dataset = PCQM4MDataset(root='dataset/', only_smiles=True)
    fp_processed_file = preprocess_fp(dataset, args.radius)

    data_dict = torch.load(fp_processed_file)
    X, Y = data_dict['X'], data_dict['Y']

    split_idx = dataset.get_idx_split()
    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        print('train subset')
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_dataset = TensorDataset(X[split_idx['train'][subset_idx]],
                                      Y[split_idx['train'][subset_idx]])

    else:
        train_dataset = TensorDataset(X[split_idx['train']],
                                      Y[split_idx['train']])

    valid_dataset = TensorDataset(X[split_idx['valid']], Y[split_idx['valid']])
    test_dataset = TensorDataset(X[split_idx['test-dev']],
                                 Y[split_idx['test-dev']])

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    if args.save_test_dir != '':
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    model = MLP(num_mlp_layers=args.num_mlp_layers,
                emb_dim=args.emb_dim,
                drop_ratio=args.drop_ratio).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           osp.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir,
                                               mode='test-dev')

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
def main():  
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071), # 80% train
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True) # seq_mode implemented  
    
    mval_loader  = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True, # True, because we use part of trainset as testset
                                      data_sel=(99965071, 101065071),#104965071),#(99965071, 124950714), # 20% for test
                                      batch_size=TS_BATCH_SZ,
                                      shuffle=False,
                                      seq_mode=True) 
    
    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)  
    
    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0        
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),key=os.path.getctime)  
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath,checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']
        
    # Train    
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query    = 0
        total_trloss_qlog = 0
        total_trloss_skip = 0
        total_trloss   = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SM.train()
            x, labels, y_mask, num_items, index = next(tr_sessions_iter) # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS
            
            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...        
            num_support = num_items[:,0].detach().numpy().flatten() # If num_items was odd number, query has one more item. 
            num_query   = num_items[:,1].detach().numpy().flatten()
            batch_sz    = num_items.shape[0]
    
            # x: bx70*20
            x = x.permute(0,2,1)
            
            # Prepare ground truth log and label, y
            y_qlog = x[:,:41,:].clone() # bx41*20
            y_skip = labels.clone() #bx20
            y_mask_qlog = y_mask.unsqueeze(1).repeat(1,41,1) #bx41*20
            y_mask_skip = y_mask #bx20
    
            # log shift: bx41*20
            log_shift = torch.zeros(batch_sz,41,20)
            log_shift[:,:,1:] = x[:,:41,:-1]
            log_shift[:,:,11:] = 0 # DELETE LOG QUE
            
            # labels_shift: bx1*20(model can only observe past labels)
            labels_shift = torch.zeros(batch_sz,1,20)
            labels_shift[:,0,1:] = labels[:,:-1].float()
            labels_shift[:,0,11:] = 0 #!!! NOLABEL for previous QUERY
            
            # support/query state labels: bx1*20
            sq_state = torch.zeros(batch_sz,1,20)
            sq_state[:,0,:11] = 1
            
            # Pack x: bx72*20 (or bx32*20 if not using sup_logs)
            x = Variable(torch.cat((log_shift, x[:,41:,:], labels_shift, sq_state), 1)).cuda(GPU) # x: bx72*20
  
            # Forward & update
            y_hat_qlog, y_hat_skip = SM(x) # y_hat: b*20
            
            # Calculate BCE loss
            loss_qlog = F.binary_cross_entropy_with_logits(input=y_hat_qlog.cuda(GPU)*y_mask_qlog.cuda(GPU),
                                                           target=y_qlog.cuda(GPU)*y_mask_qlog.cuda(GPU))
            loss_skip = F.binary_cross_entropy_with_logits(input=y_hat_skip.cuda(GPU)*y_mask_skip.cuda(GPU),
                                                           target=y_skip.cuda(GPU)*y_mask_skip.cuda(GPU))
            loss      = loss_qlog + loss_skip
            total_trloss_qlog += loss_qlog.item()
            total_trloss_skip += loss_skip.item()
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()
            
            # Decision
            y_prob = torch.sigmoid(y_hat_skip.detach()*y_mask_skip.cuda(GPU)).cpu().numpy() # bx20               
            y_pred = (y_prob[:,10:]>=0.5).astype(int) # bx10
            y_numpy = y_skip[:,10:].numpy() # bx10
            
            # Label Acc*
            total_corrects += np.sum((y_pred==y_numpy)*y_mask_skip[:,10:].numpy())
            total_query += np.sum(num_query)
#            # Log generation Acc*
#            y_qlog_mask = y_mask[:,:41,10:]
            
            # Restore GPU memory
            del loss, loss_qlog, loss_skip, y_hat_qlog, y_hat_skip 
    
            if (session+1)%500 == 0:
                hist_trloss_qlog.append(total_trloss_qlog/500) #!
                hist_trloss_skip.append(total_trloss_skip/500) #!
                hist_trloss.append(total_trloss/500)
                hist_tracc.append(total_corrects/total_query)
                # Prepare display
                sample_sup = labels[0,(10-num_support[0]):10].long().numpy().flatten() 
                sample_que = y_numpy[0,:num_query[0]].astype(int)
                sample_pred = y_pred[0,:num_query[0]]
                sample_prob = y_prob[0,10:10+num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) +'\n'+
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss(qlog|skip):{1:.6f}({2:.6f}|{3:.6f})  tr_acc:{4:.4f}".format(session,
                           hist_trloss[-1], hist_trloss_qlog[-1], hist_trloss_skip[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query    = 0
                total_trloss   = 0
                total_trloss_qlog   = 0
                total_trloss_skip   = 0
            
            if (session+1)%8000 == 0:
                 # Validation
                 validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                 # Save
                 torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 
                             'hist_trloss_qlog': hist_trloss_qlog, 'hist_trloss_skip': hist_trloss_skip,  'hist_vacc': hist_vacc,
                             'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                             'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1],
                    'hist_trloss_qlog': hist_trloss_qlog, 'hist_trloss_skip': hist_trloss_skip,  'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
Example #5
def main():
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% train
        batch_size=TR_BATCH_SZ,
        shuffle=True)  # shuffle needs to be set to True later...

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 124950714),  # 20% for test
        batch_size=2048,
        shuffle=False)

    # Init neural net
    #FeatEnc = MLP(input_sz=29, hidden_sz=512, output_sz=64).apply(weights_init).cuda(GPU)
    FeatEnc = MLP(input_sz=29, hidden_sz=256, output_sz=64).cuda(GPU)
    RN = RelationNetwork().cuda(GPU)

    FeatEnc_optim = torch.optim.Adam(FeatEnc.parameters(), lr=LEARNING_RATE)
    RN_optim = torch.optim.Adam(RN.parameters(), lr=LEARNING_RATE)

    FeatEnc_scheduler = StepLR(FeatEnc_optim, step_size=100000, gamma=0.2)
    RN_scheduler = StepLR(RN_optim, step_size=100000, gamma=0.2)

    if args.load_continue_latest is None:
        START_EPOCH = 0

    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['hist_trloss'][-1]))
        FeatEnc.load_state_dict(checkpoint['FE_state'])
        RN.load_state_dict(checkpoint['RN_state'])
        FeatEnc_optim.load_state_dict(checkpoint['FE_opt_state'])
        RN_optim.load_state_dict(checkpoint['RN_opt_state'])
        FeatEnc_scheduler.load_state_dict(checkpoint['FE_sch_state'])
        RN_scheduler.load_state_dict(checkpoint['RN_sch_state'])
        START_EPOCH = checkpoint['ep']

    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):

        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):

            FeatEnc.train()
            RN.train()
            x_sup, x_que, x_log_sup, x_log_que, label_sup, label_que, num_items, index = next(
                tr_sessions_iter)  # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS
            x_sup, x_que = Variable(x_sup).cuda(GPU), Variable(x_que).cuda(GPU)
            x_log_sup, x_log_que = Variable(x_log_sup).cuda(GPU), Variable(
                x_log_que).cuda(GPU)
            label_sup = Variable(label_sup).cuda(GPU)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten(
            )  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            x_sup = x_sup.unsqueeze(2)  # 1x7*29 --> 1x7x1*29
            x_que = x_que.unsqueeze(2)  # 1x8*29 --> 1x8x1*29

            # - feature encoder
            x_feat_sup = FeatEnc(x_sup)  # 1x7x1*64
            x_feat_que = FeatEnc(x_que)  # 1x8x1*64

            # - relation network
            y_hat = RN(x_feat_sup, x_feat_que, x_log_sup, x_log_que,
                       label_sup)  # bx8

            # Prepare ground-truth similarity score and mask
            y_gt = label_que[:, :, 1]
            y_mask = np.zeros((batch_sz, 10), dtype=np.float32)
            for b in np.arange(batch_sz):
                y_mask[b, :num_query[b]] = 1
            y_mask = torch.FloatTensor(y_mask).cuda(GPU)

            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(input=y_hat * y_mask,
                                                      target=y_gt.cuda(GPU) *
                                                      y_mask)
            total_trloss += loss.item()

            # Update Nets
            FeatEnc.zero_grad()
            RN.zero_grad()

            loss.backward()
            #torch.nn.utils.clip_grad_norm_(FeatEnc.parameters(), 0.5)
            #torch.nn.utils.clip_grad_norm_(RN.parameters(), 0.5)

            FeatEnc_optim.step()
            RN_optim.step()

            # Decision
            y_prob = (torch.sigmoid(y_hat) * y_mask).detach().cpu().numpy()
            y_pred = ((torch.sigmoid(y_hat) > 0.5).float() *
                      y_mask).detach().cpu().long().numpy()

            # Prepare display
            sample_sup = label_sup[0, :num_support[0],
                                   1].detach().long().cpu().numpy().flatten()
            sample_que = label_que[0, :num_query[0],
                                   1].long().numpy().flatten()
            sample_pred = y_pred[0, :num_query[0]].flatten()
            sample_prob = y_prob[0, :num_query[0]].flatten()

            # Acc
            total_corrects += np.sum(
                (y_pred == label_que[:, :, 1].long().numpy()) *
                y_mask.cpu().numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, x_feat_sup, x_feat_que, y_hat

            if (session + 1) % 900 == 0:
                hist_trloss.append(total_trloss / 900)
                hist_tracc.append(total_corrects / total_query)
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))

                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % 4000 == 0:
                # Validation
                validate(mval_loader, FeatEnc, RN, eval_mode=True)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'FE_state': FeatEnc.state_dict(),
                        'RN_state': RN.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'FE_opt_state': FeatEnc_optim.state_dict(),
                        'RN_opt_state': RN_optim.state_dict(),
                        'FE_sch_state': FeatEnc_scheduler.state_dict(),
                        'RN_sch_state': RN_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))

        # Validation
        validate(mval_loader, FeatEnc, RN, eval_mode=True)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'FE_state': FeatEnc.state_dict(),
                'RN_state': RN.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'FE_opt_state': FeatEnc_optim.state_dict(),
                'RN_opt_state': RN_optim.state_dict(),
                'FE_sch_state': FeatEnc_scheduler.state_dict(),
                'RN_sch_state': RN_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    # while(1):
    #     a=2
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    traindir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                            'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_hl/cv/ImageNet/',
                          'ILSVRC2012_img_val_rec')
    train_loader, val_loader = getTrainValDataset(traindir, valdir,
                                                  batch_sizes, 100, num_gpu,
                                                  num_workers)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')

    model_t = ResNet50()

    # model_kd = resnet101(pretrained=False)

    #print(model_kd)
    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t
    new_state_dict_t = OrderedDict()

    new_state_dict_t = state_dict_t

    model_t.load_state_dict(new_state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = ResNet50_sprase().to(args.gpus[0])
    model_dict_s = model_s.state_dict()
    model_dict_s.update(new_state_dict_t)
    model_s.load_state_dict(model_dict_s)

    #ckpt_kd = torch.load('resnet101-5d3b4d8f.pth', map_location=torch.device(f"cuda:{args.gpus[0]}"))
    #state_dict_kd = ckpt_kd
    #new_state_dict_kd = state_dict_kd
    #model_kd.load_state_dict(new_state_dict_kd)
    #model_kd = model_kd.to(args.gpus[0])

    #for para in list(model_kd.parameters())[:-2]:
    #para.requires_grad = False

    model_d = Discriminator().to(args.gpus[0])

    model_s = nn.DataParallel(model_s).cuda()
    model_t = nn.DataParallel(model_t).cuda()
    model_d = nn.DataParallel(model_d).cuda()

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr * 100, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_s = ckpt['state_dict_s']
        state_dict_d = ckpt['state_dict_d']

        new_state_dict_s = OrderedDict()
        for k, v in state_dict_s.items():
            new_state_dict_s['module.' + k] = v

        best_prec1 = ckpt['best_prec1']
        model_s.load_state_dict(new_state_dict_s)
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        start_epoch = ckpt['epoch']
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    models = [model_t, model_s, model_d]  #, model_kd]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        #global g_e
        #g_e = epoch
        #gl.set_value('epoch',g_e)

        train(args, train_loader, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, val_loader, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        train_loader.reset()
        val_loader.reset()
        #if is_best:
        checkpoint.save_model(state, epoch + 1, is_best)
        #checkpoint.save_model(state, 1, False)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
Example #7
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=device)

    if args.pruned:
        state_dict = checkpoint['state_dict_s']
        if args.arch == 'vgg':
            cfg = checkpoint['cfg']
            model = vgg_16_bn_sparse(cfg=cfg).to(device)
        # pruned = sum([1 for m in mask if mask == 0])
        # print(f"Pruned / Total: {pruned} / {len(mask)}")
        elif args.arch == 'resnet':
            mask = checkpoint['mask']
            model = resnet_56_sparse(has_mask=mask).to(device)

        elif args.arch == 'densenet':
            filters = checkpoint['filters']
            indexes = checkpoint['indexes']
            model = densenet_40_sparse(filters=filters,
                                       indexes=indexes).to(device)
        elif args.arch == 'googlenet':
            mask = checkpoint['mask']
            model = googlenet_sparse(has_mask=mask).to(device)
        model.load_state_dict(state_dict)
    else:
        model = import_module('utils.preprocess').__dict__[f'{args.arch}'](
            args, checkpoint['state_dict_s'])
    '''
    print_logger.info(f"Simply test after pruning...")
    test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, 0)
    '''
    if args.test_only:
        return

    if args.keep_grad:
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    train_param = [
        param for name, param in model.named_parameters() if 'mask' not in name
    ]

    optimizer = optim.SGD(train_param,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=device)
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer,
              writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model,
                                      criterion, writer_test, epoch)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)

    print_logger.info(
        f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    # Model compression info
    flops, params = get_model_complexity_info(model.to(device), (3, 32, 32),
                                              as_strings=False,
                                              print_per_layer_stat=True)
    compressionInfo(flops, params)
def main():
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # 80% train
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 104965071),  #(99965071, 124950714), # 20% for test
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            x, labels, y_mask, num_items, index = next(
                tr_sessions_iter)  # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten(
            )  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0, 2, 1)  # bx70*20
            x_sup = Variable(
                torch.cat((x[:, :, :10], labels[:, :10].unsqueeze(1)),
                          1)).cuda(GPU)  # bx71(41+29+1)*10
            x_que = torch.zeros(batch_sz, 72, 20)
            x_que[:, :41, :10] = x[:, :41, :10].clone()  # fill with x_sup_log
            x_que[:, 41:70, :] = x[:, 41:, :].clone(
            )  # fill with x_sup_feat and x_que_feat
            x_que[:, 70, :10] = 1  # support marking
            x_que[:, 71, :10] = labels[:, :10]  # labels marking
            x_que = Variable(x_que).cuda(GPU)  # bx29*10

            # y
            y = labels.clone()  # bx20

            # y_mask
            y_mask_que = y_mask.clone()
            y_mask_que[:, :10] = 0

            # Forward & update
            y_hat, att = SM(x_sup, x_que)  # y_hat: b*20, att: bx10*20

            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask_que.cuda(GPU),
                target=y.cuda(GPU) * y_mask_que.cuda(GPU))
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask_que.cuda(GPU)).detach().cpu().numpy()  # bx20
            y_pred = (y_prob[:, 10:] > 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            total_corrects += np.sum(
                (y_pred == y_numpy) * y_mask_que[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, y_hat

            if (session + 1) % 500 == 0:
                hist_trloss.append(total_trloss / 500)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_att = att[0, (10 - num_support[0]):10,
                                 (10 - num_support[0]):(
                                     10 +
                                     num_query[0])].detach().cpu().numpy()

                sample_sup = labels[0, (
                    10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]

                tqdm.write(
                    np.array2string(sample_att,
                                    formatter={
                                        'float_kind':
                                        lambda sample_att: "%.2f" % sample_att
                                    }).replace('\n ', '').replace(
                                        '][', ']\n[').replace('[[', '['))
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % 25000 == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'SM_state': SM.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'SM_opt_state': SM_optim.state_dict(),
                        'SM_sch_state': SM_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
Example #9
    with torch.no_grad():
        d.clamp_(bounds[0], bounds[1])

    # Save designs to list
    designs.append(d.clone().detach())

# --- SAVE DATA --- #

# clean up lists
train_loss = np.array([i.cpu().data.numpy() for i in train_loss])
train_loss_viaNWJ = np.array([i.cpu().data.numpy() for i in train_loss_viaNWJ])
designs = np.array([dd.cpu().tolist() for dd in designs])

# create save_dict
save_dict = dict()
save_dict['seed'] = SEED
save_dict['modelparams_jsd'] = modelparams
save_dict['d_init'] = d_init
save_dict['model_init_state'] = model_init_state
save_dict['designs_train_jsd'] = designs
save_dict['model_jsd'] = model.state_dict()
save_dict['train_loss_jsd'] = train_loss
save_dict['train_loss_jsd_viaNWJ'] = train_loss_viaNWJ
save_dict['optimizer_psi_state'] = optimizer_psi.state_dict()
save_dict['optimizer_design_state'] = optimizer_design.state_dict()
save_dict['scheduler_psi_state'] = scheduler_psi.state_dict()
save_dict['scheduler_design_state'] = scheduler_design.state_dict()

# save data
torch.save(save_dict, FILENAME)
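The snippet above only writes the checkpoint; as a hedged counterpart, a resume sketch assuming the model, both optimizers and both schedulers have been re-created with the same configuration before loading:

# Sketch of restoring the states saved in save_dict above (assumes model,
# optimizer_psi, optimizer_design, scheduler_psi and scheduler_design were
# rebuilt with the same hyperparameters, and FILENAME is unchanged).
loaded = torch.load(FILENAME)
model.load_state_dict(loaded['model_jsd'])
optimizer_psi.load_state_dict(loaded['optimizer_psi_state'])
optimizer_design.load_state_dict(loaded['optimizer_design_state'])
scheduler_psi.load_state_dict(loaded['scheduler_psi_state'])
scheduler_design.load_state_dict(loaded['scheduler_design_state'])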
def train():

    parser = argparse.ArgumentParser(
        description='PyTorch Medical Segmentation Training')
    parser = parse_training_args(parser)
    args, _ = parser.parse_known_args()

    args = parser.parse_args()

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    from data_function import MedData_train
    os.makedirs(args.output_dir, exist_ok=True)

    if hp.mode == '2d':
        from models.two_d.unet import Unet
        model = Unet(in_channels=hp.in_class, classes=hp.out_class)

        # from models.two_d.miniseg import MiniSeg
        # model = MiniSeg(in_input=hp.in_class, classes=hp.out_class)

        # from models.two_d.fcn import FCN32s as fcn
        # model = fcn(in_class =hp.in_class,n_class=hp.out_class)

        # from models.two_d.segnet import SegNet
        # model = SegNet(input_nbr=hp.in_class,label_nbr=hp.out_class)

        # from models.two_d.deeplab import DeepLabV3
        # model = DeepLabV3(in_class=hp.in_class,class_num=hp.out_class)

        # from models.two_d.unetpp import ResNet34UnetPlus
        # model = ResNet34UnetPlus(num_channels=hp.in_class,num_class=hp.out_class)

        # from models.two_d.pspnet import PSPNet
        # model = PSPNet(in_class=hp.in_class,n_classes=hp.out_class)

    elif hp.mode == '3d':

        from models.three_d.unet3d import UNet3D
        model = UNet3D(in_channels=hp.in_class,
                       out_channels=hp.out_class,
                       init_features=32)

        # from models.three_d.residual_unet3d import UNet
        # model = UNet(in_channels=hp.in_class, n_classes=hp.out_class, base_n_filter=2)

        #from models.three_d.fcn3d import FCN_Net
        #model = FCN_Net(in_channels =hp.in_class,n_class =hp.out_class)

        #from models.three_d.highresnet import HighRes3DNet
        #model = HighRes3DNet(in_channels=hp.in_class,out_channels=hp.out_class)

        #from models.three_d.densenet3d import SkipDenseNet3D
        #model = SkipDenseNet3D(in_channels=hp.in_class, classes=hp.out_class)

        # from models.three_d.densevoxelnet3d import DenseVoxelNet
        # model = DenseVoxelNet(in_channels=hp.in_class, classes=hp.out_class)

        #from models.three_d.vnet3d import VNet
        #model = VNet(in_channels=hp.in_class, classes=hp.out_class)

    model = torch.nn.DataParallel(model, device_ids=devicess)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr)

    # scheduler = ReduceLROnPlateau(optimizer, 'min',factor=0.5, patience=20, verbose=True)
    scheduler = StepLR(optimizer,
                       step_size=hp.scheduer_step_size,
                       gamma=hp.scheduer_gamma)
    # scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=5e-6)

    if args.ckpt is not None:
        print("load model:", args.ckpt)
        print(os.path.join(args.output_dir, args.latest_checkpoint_file))
        ckpt = torch.load(os.path.join(args.output_dir,
                                       args.latest_checkpoint_file),
                          map_location=lambda storage, loc: storage)

        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optim"])

        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        # scheduler.load_state_dict(ckpt["scheduler"])
        elapsed_epochs = ckpt["epoch"]
    else:
        elapsed_epochs = 0

    model.cuda()

    from loss_function import Binary_Loss, DiceLoss
    criterion = Binary_Loss().cuda()

    writer = SummaryWriter(args.output_dir)

    train_dataset = MedData_train(source_train_dir, label_train_dir)
    train_loader = DataLoader(train_dataset.queue_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    model.train()

    epochs = args.epochs - elapsed_epochs
    iteration = elapsed_epochs * len(train_loader)

    for epoch in range(1, epochs + 1):
        print("epoch:" + str(epoch))
        epoch += elapsed_epochs

        num_iters = 0

        for i, batch in enumerate(train_loader):

            if hp.debug:
                if i >= 1:
                    break

            print(f"Batch: {i}/{len(train_loader)} epoch {epoch}")

            optimizer.zero_grad()

            if (hp.in_class == 1) and (hp.out_class == 1):
                x = batch['source']['data']
                y = batch['label']['data']

                x = x.type(torch.FloatTensor).cuda()
                y = y.type(torch.FloatTensor).cuda()

            else:
                x = batch['source']['data']
                y_atery = batch['atery']['data']
                y_lung = batch['lung']['data']
                y_trachea = batch['trachea']['data']
                y_vein = batch['atery']['data']

                x = x.type(torch.FloatTensor).cuda()

                y = torch.cat((y_atery, y_lung, y_trachea, y_vein), 1)
                y = y.type(torch.FloatTensor).cuda()

            if hp.mode == '2d':
                x = x.squeeze(4)
                y = y.squeeze(4)

                y[y != 0] = 1

            # print(y.max())

            outputs = model(x)

            # for metrics
            logits = torch.sigmoid(outputs)
            labels = logits.clone()
            labels[labels > 0.5] = 1
            labels[labels <= 0.5] = 0

            loss = criterion(outputs, y)

            num_iters += 1
            loss.backward()

            optimizer.step()
            iteration += 1

            false_positive_rate, false_negtive_rate, dice = metric(
                y.cpu(), labels.cpu())
            ## log
            writer.add_scalar('Training/Loss', loss.item(), iteration)
            writer.add_scalar('Training/false_positive_rate',
                              false_positive_rate, iteration)
            writer.add_scalar('Training/false_negtive_rate',
                              false_negtive_rate, iteration)
            writer.add_scalar('Training/dice', dice, iteration)

            print("loss:" + str(loss.item()))
            print('lr:' + str(scheduler.get_last_lr()[0]))

        scheduler.step()

        # Store latest checkpoint in each epoch
        torch.save(
            {
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
            },
            os.path.join(args.output_dir, args.latest_checkpoint_file),
        )

        # Save checkpoint
        if epoch % args.epochs_per_checkpoint == 0:

            torch.save(
                {
                    "model": model.state_dict(),
                    "optim": optimizer.state_dict(),
                    "epoch": epoch,
                },
                os.path.join(args.output_dir, f"checkpoint_{epoch:04d}.pt"),
            )

            with torch.no_grad():
                if hp.mode == '2d':
                    x = x.unsqueeze(4)
                    y = y.unsqueeze(4)
                    outputs = outputs.unsqueeze(4)

                x = x[0].cpu().detach().numpy()
                y = y[0].cpu().detach().numpy()
                outputs = outputs[0].cpu().detach().numpy()
                affine = batch['source']['affine'][0].numpy()

                if (hp.in_class == 1) and (hp.out_class == 1):
                    source_image = torchio.ScalarImage(tensor=x, affine=affine)
                    source_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-source" +
                                     hp.save_arch))
                    # source_image.save(os.path.join(args.output_dir,("step-{}-source.mhd").format(epoch)))

                    label_image = torchio.ScalarImage(tensor=y, affine=affine)
                    label_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-gt" + hp.save_arch))

                    output_image = torchio.ScalarImage(tensor=outputs,
                                                       affine=affine)
                    output_image.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict" + hp.save_arch))
                else:
                    y = np.expand_dims(y, axis=1)
                    outputs = np.expand_dims(outputs, axis=1)

                    source_image = torchio.ScalarImage(tensor=x, affine=affine)
                    source_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-source" +
                                     hp.save_arch))

                    label_image_artery = torchio.ScalarImage(tensor=y[0],
                                                             affine=affine)
                    label_image_artery.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_artery" + hp.save_arch))

                    output_image_artery = torchio.ScalarImage(
                        tensor=outputs[0], affine=affine)
                    output_image_artery.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_artery" + hp.save_arch))

                    label_image_lung = torchio.ScalarImage(tensor=y[1],
                                                           affine=affine)
                    label_image_lung.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_lung" + hp.save_arch))

                    output_image_lung = torchio.ScalarImage(tensor=outputs[1],
                                                            affine=affine)
                    output_image_lung.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_lung" + hp.save_arch))

                    label_image_trachea = torchio.ScalarImage(tensor=y[2],
                                                              affine=affine)
                    label_image_trachea.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_trachea" + hp.save_arch))

                    output_image_trachea = torchio.ScalarImage(
                        tensor=outputs[2], affine=affine)
                    output_image_trachea.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_trachea" +
                            hp.save_arch))

                    label_image_vein = torchio.ScalarImage(tensor=y[3],
                                                           affine=affine)
                    label_image_vein.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_vein" + hp.save_arch))

                    output_image_vein = torchio.ScalarImage(tensor=outputs[3],
                                                            affine=affine)
                    output_image_vein.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_vein" + hp.save_arch))

    writer.close()
Example #11
def main_mlp():
    # Training settings
    parser = argparse.ArgumentParser(
        description="GNN baselines on ogbgmol* data with Pytorch Geometrics")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument(
        "--num_mlp_layers",
        type=int,
        default=6,
        help="number of mlp layers (default: 6)",
    )
    parser.add_argument("--drop_ratio",
                        type=float,
                        default=0.2,
                        help="dropout ratio (default: 0.2)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--emb_dim",
        type=int,
        default=1600,
        help="embedding dimensionality (default: 1600)",
    )
    parser.add_argument("--train_subset", action="store_true")
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument("--num_workers",
                        type=int,
                        default=0,
                        help="number of workers (default: 0)")
    parser.add_argument("--radius",
                        type=int,
                        default=2,
                        help="radius (default: 2)")
    parser.add_argument("--log_dir",
                        type=str,
                        default="",
                        help="tensorboard log directory")
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        default="",
                        help="directory to save checkpoint")
    parser.add_argument(
        "--save_test_dir",
        type=str,
        default="",
        help="directory to save test submission file",
    )
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))

    dataset = PCQM4MDataset(root="dataset/", only_smiles=True)
    fp_processed_file = preprocess_fp(dataset, args.radius)

    data_dict = torch.load(fp_processed_file)
    X, Y = data_dict["X"], data_dict["Y"]

    split_idx = dataset.get_idx_split()
    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        print("train subset")
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_dataset = TensorDataset(X[split_idx["train"][subset_idx]],
                                      Y[split_idx["train"][subset_idx]])

    else:
        train_dataset = TensorDataset(X[split_idx["train"]],
                                      Y[split_idx["train"]])

    valid_dataset = TensorDataset(X[split_idx["valid"]], Y[split_idx["valid"]])
    test_dataset = TensorDataset(X[split_idx["test"]], Y[split_idx["test"]])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    if args.save_test_dir != "":
        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )

    if args.checkpoint_dir != "":
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    model = MLP(
        num_mlp_layers=args.num_mlp_layers,
        emb_dim=args.emb_dim,
        drop_ratio=args.drop_ratio,
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"#Params: {num_params}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != "":
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print("Training...")
        train_mae = train(model, device, train_loader, optimizer)

        print("Evaluating...")
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({"Train": train_mae, "Validation": valid_mae})

        if args.log_dir != "":
            writer.add_scalar("valid/mae", valid_mae, epoch)
            writer.add_scalar("train/mae", train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != "":
                print("Saving checkpoint...")
                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_val_mae": best_valid_mae,
                    "num_params": num_params,
                }
                torch.save(checkpoint,
                           osp.join(args.checkpoint_dir, "checkpoint.pt"))

            if args.save_test_dir != "":
                print("Predicting on test data...")
                y_pred = test(model, device, test_loader)
                print("Saving test submission file...")
                evaluator.save_test_submission({"y_pred": y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f"Best validation MAE so far: {best_valid_mae}")

    if args.log_dir != "":
        writer.close()
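The checkpoint.pt written above contains everything needed to resume a run; a small helper sketch, assuming the MLP, optimizer, and scheduler have been rebuilt exactly as in the script (the function name is illustrative, not part of the original code):

import torch


def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device="cpu"):
    """Restore the state saved by the loop above; returns (start_epoch, best_valid_mae)."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Continue after the epoch that was saved, keeping the best MAE seen so far.
    return checkpoint["epoch"] + 1, checkpoint["best_val_mae"]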
Example #12
def main(args):
    prepartion(args)
    nn_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    # automatic dataloading and splitting
    dataset = MyPCQM4MDataset(root=args.dataset_root)
    split_idx = dataset.get_idx_split()
    train_data = dataset[split_idx['train']]
    valid_data = dataset[split_idx['valid']]
    test_data = dataset[split_idx['test']]
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_data,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    # automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()
    criterion_fn = torch.nn.MSELoss()

    device = args.device

    model = GINGraphPooling(**nn_params).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}', file=args.output_file, flush=True)
    print(model, file=args.output_file, flush=True)

    optimizer = optim.Adam(model.parameters(),
                           lr=0.001,
                           weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    writer = SummaryWriter(log_dir=args.save_dir)
    not_improved = 0
    best_valid_mae = 9999
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch), file=args.output_file, flush=True)
        print('Training...', file=args.output_file, flush=True)
        train_mae = train(model, device, train_loader, optimizer, criterion_fn)

        print('Evaluating...', file=args.output_file, flush=True)
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({
            'Train': train_mae,
            'Validation': valid_mae
        },
              file=args.output_file,
              flush=True)

        writer.add_scalar('valid/mae', valid_mae, epoch)
        writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.save_test:
                print('Saving checkpoint...',
                      file=args.output_file,
                      flush=True)
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.save_dir, 'checkpoint.pt'))
                print('Predicting on test data...',
                      file=args.output_file,
                      flush=True)
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...',
                      file=args.output_file,
                      flush=True)
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_dir)

            not_improved = 0
        else:
            not_improved += 1
            if not_improved == args.early_stop:
                print(f"Have not improved for {not_improved} epoches.",
                      file=args.output_file,
                      flush=True)
                break

        scheduler.step()
        print(f'Best validation MAE so far: {best_valid_mae}',
              file=args.output_file,
              flush=True)

    writer.close()
    args.output_file.close()
Example #13
    val_rouge, _ = evaluation.Rouge().compute_score(gts, gen)
    print('ROUGE_L', val_rouge)

    val_cider, _ = evaluation.Cider().compute_score(gts, gen)
    print('CIDEr', val_cider)

    saved_data = {
        'epoch': e,
        'opt': opt,
        'val_cider': val_cider,
        'patience': patience,
        'best_cider': best_cider,
        'state_dict': model.state_dict(),
        'optimizer': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
    }

    if not os.path.exists('saved_models/'):
        os.makedirs('saved_models/')

    if val_cider >= best_cider:
        best_cider = val_cider
        best_srt = 'best_rl' if opt.sample_rl else 'best'
        best_srt = 'best_rl_nw' if opt.sample_rl_nw else best_srt
        patience = 0
        saved_data['best_cider'] = best_cider
        saved_data['patience'] = patience
        torch.save(saved_data, 'saved_models/%s_%s.pth' % (opt.exp_name, best_srt))
    else:
        patience += 1
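The .pth files produced above bundle the early-stopping bookkeeping (epoch, patience, best CIDEr) with the weight, optimizer, and scheduler states, so a run can be inspected without rebuilding the model; a brief sketch, using a hypothetical file name that follows the '%s_%s.pth' pattern above:

import torch

# Hypothetical file name following the naming pattern used above.
saved = torch.load("saved_models/my_exp_best.pth", map_location="cpu")

# List what was stored and read back the training metadata.
print(sorted(saved.keys()))
print("epoch:", saved["epoch"],
      "best CIDEr:", saved["best_cider"],
      "patience:", saved["patience"])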
Example #14
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description="GNN baselines on pcqm4m with Pytorch Geometrics")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument(
        "--gnn",
        type=str,
        default="gin-virtual",
        help="GNN variant: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)",
    )
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="sum",
        help="graph pooling strategy mean or sum (default: sum)",
    )
    parser.add_argument("--drop_ratio",
                        type=float,
                        default=0,
                        help="dropout ratio (default: 0)")
    parser.add_argument(
        "--num_layers",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5)",
    )
    parser.add_argument(
        "--emb_dim",
        type=int,
        default=600,
        help="dimensionality of hidden units in GNNs (default: 600)",
    )
    parser.add_argument("--train_subset", action="store_true")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument("--num_workers",
                        type=int,
                        default=0,
                        help="number of workers (default: 0)")
    parser.add_argument("--log_dir",
                        type=str,
                        default="",
                        help="tensorboard log directory")
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        default="",
                        help="directory to save checkpoint")
    parser.add_argument(
        "--save_test_dir",
        type=str,
        default="",
        help="directory to save test submission file",
    )
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))

    ### automatic dataloading and splitting
    dataset = PygPCQM4MDataset(root="dataset/")

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(
            dataset[split_idx["train"][subset_idx]],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
    else:
        train_loader = DataLoader(
            dataset[split_idx["train"]],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )

    valid_loader = DataLoader(
        dataset[split_idx["valid"]],
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    if args.save_test_dir != "":
        test_loader = DataLoader(
            dataset[split_idx["test"]],
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )

    if args.checkpoint_dir != "":
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        "num_layers": args.num_layers,
        "emb_dim": args.emb_dim,
        "drop_ratio": args.drop_ratio,
        "graph_pooling": args.graph_pooling,
    }

    if args.gnn == "gin":
        model = GNN(gnn_type="gin", virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == "gin-virtual":
        model = GNN(gnn_type="gin", virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == "gcn":
        model = GNN(gnn_type="gcn", virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == "gcn-virtual":
        model = GNN(gnn_type="gcn", virtual_node=True,
                    **shared_params).to(device)
    else:
        raise ValueError("Invalid GNN type")

    num_params = sum(p.numel() for p in model.parameters())
    print(f"#Params: {num_params}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != "":
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print("Training...")
        train_mae = train(model, device, train_loader, optimizer)

        print("Evaluating...")
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({"Train": train_mae, "Validation": valid_mae})

        if args.log_dir != "":
            writer.add_scalar("valid/mae", valid_mae, epoch)
            writer.add_scalar("train/mae", train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != "":
                print("Saving checkpoint...")
                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_val_mae": best_valid_mae,
                    "num_params": num_params,
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, "checkpoint.pt"))

            if args.save_test_dir != "":
                print("Predicting on test data...")
                y_pred = test(model, device, test_loader)
                print("Saving test submission file...")
                evaluator.save_test_submission({"y_pred": y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f"Best validation MAE so far: {best_valid_mae}")

    if args.log_dir != "":
        writer.close()
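For reference, StepLR with gamma=0.25 gives lr = 0.001 * 0.25 ** (epoch // step_size), so the full-data run above (step_size=30) decays at epochs 30, 60, and 90, while the train_subset run (step_size=300 over 1000 epochs) decays at 300, 600, and 900. A throwaway sketch confirming the schedule with a dummy parameter:

import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR

# A dummy parameter just to drive the optimizer and scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.001)
scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

for epoch in range(1, 91):
    optimizer.step()
    scheduler.step()
    if epoch % 30 == 0:
        # Prints 0.00025 after epoch 30, 6.25e-05 after 60, 1.5625e-05 after 90.
        print(epoch, optimizer.param_groups[0]["lr"])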
Example #15
            tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                       np.array2string(sample_que) + '\n' + "P:" +
                       np.array2string(sample_pred))
            tqdm.write(
                "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                    session, hist_trloss[-1], hist_tracc[-1]))
            total_corrects = 0
            total_query = 0
            total_trloss = 0

        if (session + 1) % 40000 == 0:
            # Validation
            validate()
            # Save
            torch.save(
                {
                    'ep': epoch,
                    'sess': session,
                    'FE_state': FeatEnc.state_dict(),
                    'RN_state': RN.state_dict(),
                    'loss': None,
                    'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss,
                    'hist_trloss': hist_trloss,
                    'FE_opt_state': FeatEnc_optim.state_dict(),
                    'RN_opt_state': RN_optim.state_dict(),
                    'FE_sch_state': FeatEnc_scheduler.state_dict(),
                    'RN_sch_state': RN_scheduler.state_dict()
                },
                MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
Example #16
            # Save Model
            if val_accuracy > max_accuracy:
                # Prepare folder
                folder_for_this_accuracy = os.path.join(output_folder, str(val_accuracy))
                max_accuracy = val_accuracy
                print("Models Saved with accuracy={}".format(max_accuracy))
            else:
                folder_for_this_accuracy = os.path.join(output_folder, "Latest_{}".format(val_accuracy))

            if not os.path.exists(folder_for_this_accuracy):
                os.mkdir(folder_for_this_accuracy)

            # Save networks
            torch.save(c3d.state_dict(), os.path.join(folder_for_this_accuracy, "c3d.pkl"))
            torch.save(rn.state_dict(), os.path.join(folder_for_this_accuracy, "rn.pkl"))
            torch.save(tcn.state_dict(), os.path.join(folder_for_this_accuracy, "tcn.pkl"))
            torch.save(ap.state_dict(), os.path.join(folder_for_this_accuracy, "ap.pkl"))

            torch.save(c3d_optim.state_dict(), os.path.join(folder_for_this_accuracy, "c3d_optim.pkl"))
            torch.save(rn_optim.state_dict(), os.path.join(folder_for_this_accuracy, "rn_optim.pkl"))
            torch.save(tcn_optim.state_dict(), os.path.join(folder_for_this_accuracy, "tcn_optim.pkl"))
            torch.save(ap_optim.state_dict(), os.path.join(folder_for_this_accuracy, "ap_optim.pkl"))

            torch.save(c3d_scheduler.state_dict(), os.path.join(folder_for_this_accuracy, "c3d_scheduler.pkl"))
            torch.save(rn_scheduler.state_dict(), os.path.join(folder_for_this_accuracy, "rn_scheduler.pkl"))
            torch.save(tcn_scheduler.state_dict(), os.path.join(folder_for_this_accuracy, "tcn_scheduler.pkl"))
            torch.save(ap_scheduler.state_dict(), os.path.join(folder_for_this_accuracy, "ap_scheduler.pkl"))

print("Training Done")
print("Final Accuracy = {}".format(max_accuracy))
Example #17
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        
    if args.pruned:
        mask = checkpoint['mask']
        pruned = sum([1 for m in mask if m == 0])
        print(f"Pruned / Total: {pruned} / {len(mask)}")
        model = resnet_56_sparse(has_mask=mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test)
    print(f"Simply test after prune {test_prec1:.3f}")
    
    if args.test_only:
        return 

    if args.keep_grad:
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer, writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, epoch)

        is_best_finetune = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, False, is_best_finetune)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
Example #18
File: main_gnn.py Project: rpatil524/ogb
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with Pytorch Geometrics')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN variant: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = PygPCQM4MDataset(root='dataset/')

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    if args.save_test_dir != '':
        test_loader = DataLoader(dataset[split_idx["test-dev"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir,
                                               mode='test-dev')

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
Example #19
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN to use, which can be from '
        '[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset',
                        action='store_true',
                        help='use 10% of the training set for training')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory. If not specified, '
                        'tensorboard will not be used.')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = SampleDglPCQM4MDataset(root='dataset/')

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()
    split_idx["train"] = split_idx["train"].type(torch.LongTensor)
    split_idx["test"] = split_idx["test"].type(torch.LongTensor)
    split_idx["valid"] = split_idx["valid"].type(torch.LongTensor)

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_dgl)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_dgl)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)

    if args.save_test_dir != '':
        test_loader = DataLoader(dataset[split_idx["test"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_dgl)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual-diffpool':
        model = DiffPoolGNN(gnn_type='gin', virtual_node=True,
                            **shared_params).to(device)
    elif args.gnn == 'gin-virtual-bayes-diffpool':
        model = BayesDiffPoolGNN(gnn_type='gin',
                                 virtual_node=True,
                                 **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)
    """ load from latest checkpoint """
    # start epoch (default = 1, unless resuming training)
    firstEpoch = 1
    # check if checkpoint exist -> load model
    checkpointFile = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if os.path.exists(checkpointFile):
        # load checkpoint file
        checkpointData = torch.load(checkpointFile)
        firstEpoch = checkpointData["epoch"]
        model.load_state_dict(checkpointData["model_state_dict"])
        optimizer.load_state_dict(checkpointData["optimizer_state_dict"])
        scheduler.load_state_dict(checkpointData["scheduler_state_dict"])
        best_valid_mae = checkpointData["best_val_mae"]
        num_params = checkpointData["num_params"]
        print(
            "Loaded existing weights from {}. Continuing from epoch: {} with best valid MAE: {}"
            .format(checkpointFile, firstEpoch, best_valid_mae))

    for epoch in range(firstEpoch, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer, args.gnn)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
Example #20
class DeployedESTransformer(object):
    def __init__(
            self,
            max_epochs=15,
            batch_size=1,
            batch_size_test=64,
            freq_of_test=-1,
            learning_rate=1e-3,
            lr_scheduler_step_size=9,
            lr_decay=0.9,
            per_series_lr_multip=1.0,
            gradient_eps=1e-8,
            transformer_weight_decay=0,
            noise_std=0.001,
            level_variability_penalty=80,
            testing_percentile=50,
            training_percentile=50,
            ensemble=False,
            seasonality=[4],
            input_size=4,
            output_size=8,
            frequency=None,
            max_periods=20,
            random_seed=1,
            device='cpu',
            root_dir='./',

            # Transformer parameters
            d_input=4,
            d_model=48,
            d_output=6,
            q=8,
            v=8,
            h=4,
            N=4,
            attention_size=None,
            dropout=0.3,
            chunk_mode='chunk',
            pe=None,
            pe_period=24,
            dataset_name=None):

        super().__init__()
        self.mc = ModelConfig(
            max_epochs=max_epochs,
            batch_size=batch_size,
            batch_size_test=batch_size_test,
            freq_of_test=freq_of_test,
            learning_rate=learning_rate,
            lr_scheduler_step_size=lr_scheduler_step_size,
            lr_decay=lr_decay,
            per_series_lr_multip=per_series_lr_multip,
            gradient_eps=gradient_eps,
            transformer_weight_decay=transformer_weight_decay,
            noise_std=noise_std,
            level_variability_penalty=level_variability_penalty,
            testing_percentile=testing_percentile,
            training_percentile=training_percentile,
            ensemble=ensemble,
            seasonality=seasonality,
            input_size=input_size,
            output_size=output_size,
            frequency=frequency,
            max_periods=max_periods,
            random_seed=random_seed,
            device=device,
            root_dir=root_dir,
            d_input=d_input,
            d_model=d_model,
            d_output=d_output,
            q=q,
            v=v,
            h=h,
            N=N,
            attention_size=attention_size,
            dropout=dropout,
            chunk_mode=chunk_mode,
            pe=pe,
            pe_period=pe_period)
        self.device = device
        self.dataset_name = dataset_name
        self._fitted = False

    def instantiate_estransformer(self, exogenous_size, n_series):
        self.mc.exogenous_size = exogenous_size
        self.mc.n_series = n_series
        self.estransformer = ESTransformer(self.mc).to(self.mc.device)

    def fit(self,
            X_df,
            y_df,
            X_test_df=None,
            y_test_df=None,
            y_hat_benchmark='y_hat_naive2',
            warm_start=False,
            shuffle=True,
            verbose=True):
        # Transform long dfs to wide numpy
        assert type(X_df) == pd.core.frame.DataFrame
        assert type(y_df) == pd.core.frame.DataFrame
        assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
        assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
        if y_test_df is not None:
            assert y_hat_benchmark in y_test_df.columns, 'benchmark is not present in y_test_df, use y_hat_benchmark to define it'

        # Storing dfs for OWA evaluation, initializing min_owa
        self.y_train_df = y_df
        self.X_test_df = X_test_df
        self.y_test_df = y_test_df
        self.min_owa = 4.0
        self.min_epoch = 0

        self.int_ds = isinstance(self.y_train_df['ds'][0],
                                 (int, np.integer))

        self.y_hat_benchmark = y_hat_benchmark

        X, y = self.long_to_wide(X_df, y_df)
        assert len(X) == len(y)
        assert X.shape[1] >= 3

        # Exogenous variables
        unique_categories = np.unique(X[:, 1])
        self.mc.category_to_idx = dict(
            (word, index) for index, word in enumerate(unique_categories))
        exogenous_size = len(unique_categories)

        # Create batches (device in mc)
        self.train_dataloader = Iterator(mc=self.mc, X=X, y=y)

        # Random Seeds (model initialization)
        torch.manual_seed(self.mc.random_seed)
        np.random.seed(self.mc.random_seed)

        # Initialize model
        n_series = self.train_dataloader.n_series
        self.instantiate_estransformer(exogenous_size, n_series)

        # Validating frequencies
        X_train_frequency = pd.infer_freq(X_df.head()['ds'])
        y_train_frequency = pd.infer_freq(y_df.head()['ds'])
        self.frequencies = [X_train_frequency, y_train_frequency]

        if (X_test_df is not None) and (y_test_df is not None):
            X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
            y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
            self.frequencies += [X_test_frequency, y_test_frequency]

        assert len(set(self.frequencies)) <= 1, \
          "Match the frequencies of the dataframes {}".format(self.frequencies)

        self.mc.frequency = self.frequencies[0]
        print("Infered frequency: {}".format(self.mc.frequency))

        # Train model
        self._fitted = True
        self.train(dataloader=self.train_dataloader,
                   max_epochs=self.mc.max_epochs,
                   warm_start=warm_start,
                   shuffle=shuffle,
                   verbose=verbose)

    def train(self,
              dataloader,
              max_epochs,
              warm_start=False,
              shuffle=True,
              verbose=True):

        if self.mc.ensemble:
            # Five independent copies; a multiplied list would alias a single copy.
            self.estransformer_ensemble = [
                deepcopy(self.estransformer).to(self.mc.device)
                for _ in range(5)
            ]
        if verbose:
            print(15 * '=' + ' Training ESTransformer  ' + 15 * '=' + '\n')

        # Model parameters
        es_parameters = filter(lambda p: p.requires_grad,
                               self.estransformer.es.parameters())
        params = sum([np.prod(p.size()) for p in es_parameters])
        print('Number of parameters of ES: ', params)

        trans_parameters = filter(lambda p: p.requires_grad,
                                  self.estransformer.transformer.parameters())
        params = sum([np.prod(p.size()) for p in trans_parameters])
        print('Number of parameters of Transformer: ', params)

        # Optimizers
        if not warm_start:
            self.es_optimizer = optim.Adam(
                params=self.estransformer.es.parameters(),
                lr=self.mc.learning_rate * self.mc.per_series_lr_multip,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps)

            self.es_scheduler = StepLR(
                optimizer=self.es_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=0.9)

            self.transformer_optimizer = optim.Adam(
                params=self.estransformer.transformer.parameters(),
                lr=self.mc.learning_rate,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps,
                weight_decay=self.mc.transformer_weight_decay)

            self.transformer_scheduler = StepLR(
                optimizer=self.transformer_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=self.mc.lr_decay)

        all_epoch = []
        all_train_loss = []
        all_test_loss = []

        # Loss Functions
        train_tau = self.mc.training_percentile / 100
        train_loss = SmylLoss(
            tau=train_tau,
            level_variability_penalty=self.mc.level_variability_penalty)

        eval_tau = self.mc.testing_percentile / 100
        eval_loss = PinballLoss(tau=eval_tau)

        for epoch in range(max_epochs):
            self.estransformer.train()
            start = time.time()
            if shuffle: dataloader.shuffle_dataset(random_seed=epoch)
            losses = []
            for j in range(dataloader.n_batches):
                self.es_optimizer.zero_grad()
                self.transformer_optimizer.zero_grad()

                batch = dataloader.get_batch()
                windows_y, windows_y_hat, levels = self.estransformer(batch)

                # Pinball loss on normalized values
                loss = train_loss(windows_y, windows_y_hat, levels)
                losses.append(loss.data.cpu().numpy())
                loss.backward()
                self.transformer_optimizer.step()
                self.es_optimizer.step()

            # Decay learning rate
            self.es_scheduler.step()
            self.transformer_scheduler.step()

            if self.mc.ensemble:
                copy_estransformer = deepcopy(self.estransformer)
                copy_estransformer.eval()
                self.estransformer_ensemble.pop(0)
                self.estransformer_ensemble.append(copy_estransformer)

            # Evaluation
            self.train_loss = np.mean(losses)
            if verbose:
                print("========= Epoch {} finished =========".format(epoch))
                print("Training time: {}".format(round(time.time() - start,
                                                       5)))
                print("Training loss ({} prc): {:.5f}".format(
                    self.mc.training_percentile, self.train_loss))
                self.test_loss = self.model_evaluation(dataloader, eval_loss)
                print("Testing loss  ({} prc): {:.5f}".format(
                    self.mc.testing_percentile, self.test_loss))
                self.evaluate_model_prediction(self.y_train_df,
                                               self.X_test_df,
                                               self.y_test_df,
                                               self.y_hat_benchmark,
                                               epoch=epoch)
                self.estransformer.train()

                all_epoch.append(epoch)
                all_train_loss.append(self.train_loss)
                all_test_loss.append(self.test_loss)

                converge = pd.DataFrame({
                    'Epoch': all_epoch,
                    'Train loss': all_train_loss,
                    'Test loss': all_test_loss
                })
                # converge.to_csv("D:\\Sang\\hybcast\\hybcast3\\" + self.dataset_name + 'log_' + self.dataset_name +'.csv', index=False)

            if (epoch % 100 == 0) or (epoch % 499 == 0):
                # self.save(model_dir="D:\\Sang\\hybcast\\hybcast3\\" + self.dataset_name +'\\model\\', epoch=epoch)
                pass

        if verbose: print('Train finished! \n')

    def predict(self, X_df, decomposition=False):
        assert type(X_df) == pd.core.frame.DataFrame
        assert 'unique_id' in X_df
        assert self._fitted, "Model not fitted yet"

        self.estransformer.eval()

        # Create fast dataloader
        if self.mc.n_series < self.mc.batch_size_test:
            new_batch_size = self.mc.n_series
        else:
            new_batch_size = self.mc.batch_size_test
        self.train_dataloader.update_batch_size(new_batch_size)
        dataloader = self.train_dataloader

        # Create Y_hat_panel placeholders
        output_size = self.mc.output_size
        n_unique_id = len(dataloader.sort_key['unique_id'])
        panel_unique_id = pd.Series(
            dataloader.sort_key['unique_id']).repeat(output_size)

        # Access the column with the last train date
        panel_last_ds = pd.Series(dataloader.X[:, 2])
        panel_ds = []
        for i in range(len(panel_last_ds)):
            ranges = pd.date_range(start=panel_last_ds[i],
                                   periods=output_size + 1,
                                   freq=self.mc.frequency)
            panel_ds += list(ranges[1:])

        panel_y_hat = np.zeros((output_size * n_unique_id))

        # Predict
        count = 0
        for j in range(dataloader.n_batches):
            batch = dataloader.get_batch()
            batch_size = batch.y.shape[0]

            if self.mc.ensemble:
                y_hat = torch.zeros((5, batch_size, output_size))
                for i in range(5):
                    y_hat[i, :, :] = self.estransformer_ensemble[i].predict(
                        batch)
                y_hat = torch.mean(y_hat, 0)
            else:
                y_hat = self.estransformer.predict(batch)

            y_hat = y_hat.data.cpu().numpy()

            panel_y_hat[count:count +
                        output_size * batch_size] = y_hat.flatten()
            count += output_size * batch_size

        Y_hat_panel_dict = {
            'unique_id': panel_unique_id,
            'ds': panel_ds,
            'y_hat': panel_y_hat
        }

        assert len(panel_ds) == len(panel_y_hat) == len(panel_unique_id)

        Y_hat_panel = pd.DataFrame.from_dict(Y_hat_panel_dict)

        if 'ds' in X_df:
            Y_hat_panel = X_df.merge(Y_hat_panel,
                                     on=['unique_id', 'ds'],
                                     how='left')
        else:
            Y_hat_panel = X_df.merge(Y_hat_panel, on=['unique_id'], how='left')

        self.train_dataloader.update_batch_size(self.mc.batch_size)
        return Y_hat_panel

    def per_series_evaluation(self, dataloader, criterion):
        with torch.no_grad():
            # Create fast dataloader
            if self.mc.n_series < self.mc.batch_size_test:
                new_batch_size = self.mc.n_series
            else:
                new_batch_size = self.mc.batch_size_test
            dataloader.update_batch_size(new_batch_size)

            per_series_losses = []
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.estransformer(batch)
                loss = criterion(windows_y, windows_y_hat)
                per_series_losses += loss.data.cpu().numpy().tolist()

            dataloader.update_batch_size(self.mc.batch_size)
        return per_series_losses

    def model_evaluation(self, dataloader, criterion):
        with torch.no_grad():
            # Create fast dataloader
            if self.mc.n_series < self.mc.batch_size_test:
                new_batch_size = self.mc.n_series
            else:
                new_batch_size = self.mc.batch_size_test
            dataloader.update_batch_size(new_batch_size)

            model_loss = 0.0
            for j in range(dataloader.n_batches):
                batch = dataloader.get_batch()
                windows_y, windows_y_hat, _ = self.estransformer(batch)
                loss = criterion(windows_y, windows_y_hat)
                model_loss += loss.data.cpu().numpy()

            model_loss /= dataloader.n_batches
            dataloader.update_batch_size(self.mc.batch_size)
        return model_loss

    def evaluate_model_prediction(self,
                                  y_train_df,
                                  X_test_df,
                                  y_test_df,
                                  y_hat_benchmark='y_hat_naive2',
                                  epoch=None):
        assert self._fitted, "Model not fitted yet"

        y_panel = y_test_df.filter(['unique_id', 'ds', 'y'])
        y_benchmark_panel = y_test_df.filter(
            ['unique_id', 'ds', y_hat_benchmark])
        y_benchmark_panel.rename(columns={y_hat_benchmark: 'y_hat'},
                                 inplace=True)
        y_hat_panel = self.predict(X_test_df)
        y_insample = y_train_df.filter(['unique_id', 'ds', 'y'])

        model_owa, model_mase, model_smape = owa(
            y_panel,
            y_hat_panel,
            y_benchmark_panel,
            y_insample,
            seasonality=self.mc.naive_seasonality)

        if self.min_owa > model_owa:
            self.min_owa = model_owa
            if epoch is not None: self.min_epoch = epoch

        print('OWA: {} '.format(np.round(model_owa, 3)))
        print('SMAPE: {} '.format(np.round(model_smape, 3)))
        print('MASE: {} '.format(np.round(model_mase, 3)))

        return model_owa, model_mase, model_smape

    def long_to_wide(self, X_df, y_df):
        data = X_df.copy()
        data['y'] = y_df['y'].copy()
        sorted_ds = np.sort(data['ds'].unique())
        ds_map = {}
        for dmap, t in enumerate(sorted_ds):
            ds_map[t] = dmap
        data['ds_map'] = data['ds'].map(ds_map)
        data = data.sort_values(by=['ds_map', 'unique_id'])
        df_wide = data.pivot(index='unique_id', columns='ds_map')['y']

        x_unique = data[['unique_id', 'x']].groupby('unique_id').first()
        last_ds = data[['unique_id', 'ds']].groupby('unique_id').last()
        assert len(x_unique) == len(data.unique_id.unique())
        df_wide['x'] = x_unique
        df_wide['last_ds'] = last_ds
        df_wide = df_wide.reset_index().rename_axis(None, axis=1)

        ds_cols = data.ds_map.unique().tolist()
        X = df_wide.filter(items=['unique_id', 'x', 'last_ds']).values
        y = df_wide.filter(items=ds_cols).values

        return X, y

    def get_dir_name(self, root_dir=None):
        if not root_dir:
            assert self.mc.root_dir
            root_dir = self.mc.root_dir

        data_dir = self.mc.dataset_name
        model_parent_dir = os.path.join(root_dir, data_dir)
        model_path = ['estransformer_{}'.format(str(self.mc.copy))]
        model_dir = os.path.join(model_parent_dir, '_'.join(model_path))
        return model_dir

    def save(self, model_dir=None, copy=None, epoch=None):
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        print('Saving model to:\n {}'.format(model_dir) + '\n')
        torch.save(
            {
                'model_state_dict': self.estransformer.state_dict(),
                'es_optimizer': self.es_optimizer.state_dict(),
                'es_scheduler': self.es_scheduler.state_dict(),
                'transformer_optimizer':
                self.transformer_optimizer.state_dict(),
                'transformer_scheduler':
                self.transformer_scheduler.state_dict(),
                'epoch': epoch
            },
            os.path.join(model_dir, 'model_epoch_' + str(epoch) + '_' + self.dataset_name))

    def load(self, model_dir=None, copy=None, conti_train=False):
        # Run preprocess to instantialize estransformer and its optimizer
        if copy is not None:
            self.mc.copy = copy

        if not model_dir:
            assert self.mc.root_dir
            model_dir = self.get_dir_name()

        temp_model = torch.load(model_dir,
                                map_location=torch.device(self.device))

        # Load model
        self.estransformer.load_state_dict(temp_model['model_state_dict'])

        if conti_train:
            # Instantiate optimizer and scheduler
            self.es_optimizer = optim.Adam(
                params=self.estransformer.es.parameters(),
                lr=self.mc.learning_rate * self.mc.per_series_lr_multip,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps)

            self.es_scheduler = StepLR(
                optimizer=self.es_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=0.9)

            self.transformer_optimizer = optim.Adam(
                params=self.estransformer.transformer.parameters(),
                lr=self.mc.learning_rate,
                betas=(0.9, 0.999),
                eps=self.mc.gradient_eps,
                weight_decay=self.mc.transformer_weight_decay)

            self.transformer_scheduler = StepLR(
                optimizer=self.transformer_optimizer,
                step_size=self.mc.lr_scheduler_step_size,
                gamma=self.mc.lr_decay)

            # Load state
            self.es_optimizer.load_state_dict(temp_model['es_optimizer'])
            self.es_scheduler.load_state_dict(temp_model['es_scheduler'])
            self.transformer_optimizer.load_state_dict(
                temp_model['transformer_optimizer'])
            self.transformer_scheduler.load_state_dict(
                temp_model['transformer_scheduler'])
            self.min_epoch = temp_model['epoch']

            self.train(dataloader=self.train_dataloader,
                       max_epochs=self.mc.max_epochs,
                       warm_start=True,
                       shuffle=True,
                       verbose=True)

    def preprocess(self,
                   X_df,
                   y_df,
                   X_test_df=None,
                   y_test_df=None,
                   y_hat_benchmark='y_hat_naive2',
                   warm_start=False,
                   shuffle=True,
                   verbose=True):
        # Transform long dfs to wide numpy
        assert type(X_df) == pd.core.frame.DataFrame
        assert type(y_df) == pd.core.frame.DataFrame
        assert all([(col in X_df) for col in ['unique_id', 'ds', 'x']])
        assert all([(col in y_df) for col in ['unique_id', 'ds', 'y']])
        if y_test_df is not None:
            assert y_hat_benchmark in y_test_df.columns, 'benchmark column not found in y_test_df; pass its name via y_hat_benchmark'

        # Storing dfs for OWA evaluation, initializing min_owa
        self.y_train_df = y_df
        self.X_test_df = X_test_df
        self.y_test_df = y_test_df
        self.min_owa = 4.0
        self.min_epoch = 0

        self.int_ds = isinstance(self.y_train_df['ds'][0],
                                 (int, np.integer))

        self.y_hat_benchmark = y_hat_benchmark

        X, y = self.long_to_wide(X_df, y_df)
        assert len(X) == len(y)
        assert X.shape[1] >= 3

        # Exogenous variables
        unique_categories = np.unique(X[:, 1])
        self.mc.category_to_idx = dict(
            (word, index) for index, word in enumerate(unique_categories))
        exogenous_size = len(unique_categories)

        # Create batches (device in mc)
        self.train_dataloader = Iterator(mc=self.mc, X=X, y=y)

        # Random Seeds (model initialization)
        torch.manual_seed(self.mc.random_seed)
        np.random.seed(self.mc.random_seed)

        # Initialize model
        n_series = self.train_dataloader.n_series

        self.instantiate_estransformer(exogenous_size, n_series)

        # Validating frequencies
        X_train_frequency = pd.infer_freq(X_df.head()['ds'])
        y_train_frequency = pd.infer_freq(y_df.head()['ds'])
        self.frequencies = [X_train_frequency, y_train_frequency]

        if (X_test_df is not None) and (y_test_df is not None):
            X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
            y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
            self.frequencies += [X_test_frequency, y_test_frequency]

        assert len(set(self.frequencies)) <= 1, \
          "Match the frequencies of the dataframes {}".format(self.frequencies)

        self.mc.frequency = self.frequencies[0]
        print("Inferred frequency: {}".format(self.mc.frequency))

        # Train model
        self._fitted = True

def main():
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # first 80% of sessions used for training
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 104965071),  # (99965071, 124950714), # subset of the last 20% used as test
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.7)
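    # With step_size=1 and gamma=0.7, the learning rate decays by 30% each time
    # SM_scheduler.step() is called (once per epoch, at the end of the epoch loop).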

    CF_model = MLP_Regressor().cuda(GPU)
    CF_checkpoint = torch.load(CF_CHECKPOINT_PATH,
                               map_location='cuda:{}'.format(GPU))
    CF_model.load_state_dict(CF_checkpoint['model_state'])

    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            x, labels, y_mask, num_items, index = next(
                tr_sessions_iter)  # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULD NOT INCLUDE LOGS

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten(
            )  # if num_items is odd, the query set gets one more item
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x[:, 10:, :41] = 0  # zero out the meta/log features for the query items

            # labels_shift: (model can only observe past labels)
            labels_shift = torch.zeros(batch_sz, 20, 1)
            labels_shift[:, 1:, 0] = labels[:, :-1].float()
            # no label information is revealed for query positions
            labels_shift[:, 11:, 0] = 0
            # support/query state labels
            sq_state = torch.zeros(batch_sz, 20, 1)
            sq_state[:, :11, 0] = 1
            # compute lastfm_output
            x_audio = x[:, :, 41:].data.clone()
            x_audio = Variable(x_audio, requires_grad=False).cuda(GPU)
            x_emb_lastfm, x_lastfm = CF_model(x_audio)
            x_lastfm = x_lastfm.cpu()
            del x_emb_lastfm

            # Pack x: bx122*20
            x = Variable(
                torch.cat((x_lastfm, x, labels_shift, sq_state),
                          dim=2).permute(0, 2, 1)).cuda(GPU)

            # Forward & update
            y_hat = SM(x)  # y_hat: b*20
            # Calculate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask.cuda(GPU),
                target=labels.cuda(GPU) * y_mask.cuda(GPU))
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask.cuda(GPU)).detach().cpu().numpy()  # bx20
            y_pred = (y_prob[:, 10:] >= 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            y_query_mask = y_mask[:, 10:].numpy()
            total_corrects += np.sum((y_pred == y_numpy) * y_query_mask)
            total_query += np.sum(num_query)
            # Restore GPU memory
            del loss, y_hat

            if (session + 1) % 500 == 0:
                hist_trloss.append(total_trloss / 900)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_sup = labels[
                    0, :num_support[0]].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % 20000 == 0:
                # Validation
                validate(mval_loader, SM, CF_model, eval_mode=True)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'SM_state': SM.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'SM_opt_state': SM_optim.state_dict(),
                        'SM_sch_state': SM_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, CF_model, eval_mode=True)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
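
A minimal editorial sketch of the masked BCE-with-logits pattern used in the training loop above: padded / non-query positions are zeroed in both the logits and the targets through y_mask, so those slots add a constant log(2) to the loss but send no gradient back to the model. All names and shapes below are hypothetical, not taken from the original code.

import torch
import torch.nn.functional as F

# hypothetical batch: 2 sessions, 20 positions each
y_hat = torch.randn(2, 20)                     # raw logits from the sequence model
labels = torch.randint(0, 2, (2, 20)).float()  # 0/1 skip labels
y_mask = torch.ones(2, 20)
y_mask[1, 15:] = 0                             # pretend the second session has only 15 valid items

# zeroing both input and target at masked positions gives sigmoid(0) = 0.5 vs. target 0:
# a constant log(2) contribution with zero gradient w.r.t. y_hat at those slots
loss = F.binary_cross_entropy_with_logits(y_hat * y_mask, labels * y_mask)
print(loss.item())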
Example #22
def train(data_dir, model_dir, checkpoint_path, pretrained_dvector_path,
          n_steps, save_every, decay_every, seg_len, ratio):
    """Train speaker verifier"""

    # setup
    total_steps = 0
    assert os.path.isdir(model_dir)

    # load data
    dataset = SVDataset(data_dir, seg_len)
    train_index = sample_index(len(dataset), ratio)
    valid_index = [x for x in range(len(dataset)) if x not in train_index]
    train_set = Subset(dataset, train_index)
    valid_set = Subset(dataset, valid_index)
    train_loader = DataLoader(train_set, batch_size=1024, shuffle=True,
                              collate_fn=pad_batch_with_label, drop_last=False)
    valid_loader = DataLoader(valid_set, batch_size=2, shuffle=False,
                              collate_fn=pad_batch_with_label, drop_last=False)
    train_loader_iter = iter(train_loader)
    print(f"Training starts with {train_set.dataset.total} speakers.")

    # load checkpoint
    ckpt = None
    dvector_path = pretrained_dvector_path  # default when no checkpoint is resumed
    if checkpoint_path is not None:
        ckpt = torch.load(checkpoint_path)
        dvector_path = ckpt["dvector_path"]

    # build network and training tools
    model = SpeakerVerifier(pretrained_dvector_path, dataset.total)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)
    if ckpt is not None:
        total_steps = ckpt["total_steps"]
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    # start training
    for step in pbar:

        total_steps += 1

        try:
            batch = next(train_loader_iter)
        except StopIteration:
            train_loader_iter = iter(train_loader)
            batch = next(train_loader_iter)

        data, label = batch
        logits = model(data.to(device))

        loss = criterion(logits, torch.LongTensor(label).to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
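        # The scheduler is stepped after every optimization step, so decay_every
        # (the StepLR step_size) is measured in training steps, not epochs.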

        pbar.set_description(f"global = {total_steps}, loss = {loss:.4f}")
        writer.add_scalar("train_loss", loss, total_steps)

        if (step + 1) % save_every == 0:
            ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
            ckpt_dict = {
                "total_steps": total_steps,
                "dvector_path": dvector_path,
                "state_dict": model.state_dict(),
                "criterion": criterion.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(ckpt_dict, ckpt_path)

        if (step + 1) % save_every == 0:
            val_acc = 0.0
            val_loss = 0.0
            for batch in valid_loader:
                data, label = batch
                with torch.no_grad():
                    logits = model(data.to(device))
                    pred = logits.argmax(dim=1)
                    val_acc += (pred ==
                                torch.LongTensor(label).to(device)).sum().item()
                    val_loss += criterion(logits,
                                          torch.LongTensor(label).to(device)).item()
            val_acc /= len(valid_set)
            val_loss /= len(valid_loader)
            writer.add_scalar("valid_accuracy", val_acc, total_steps)
            writer.add_scalar("valid_loss", val_loss, total_steps)

    print("Training completed.")
Example #23
class DeepQNetworkOptionAgent:
    def __init__(self,
                 hex_diffusion,
                 option_num,
                 isoption=False,
                 islocal=True,
                 ischarging=True):
        self.learning_rate = 1e-3  # 1e-4
        self.gamma = GAMMA
        self.start_epsilon = START_EPSILON
        self.final_epsilon = FINAL_EPSILON
        self.epsilon_steps = EPSILON_DECAY_STEPS
        self.memory = BatchReplayMemory(256)
        self.batch_size = BATCH_SIZE
        self.clipping_value = CLIPPING_VALUE
        self.input_dim = INPUT_DIM  # 3 input state
        self.relocation_dim = RELOCATION_DIM  # 7
        self.charging_dim = CHARGING_DIM  # 5
        self.option_dim = OPTION_DIM  # 3
        self.output_dim = DQN_OUTPUT_DIM  # 7+5+3 = 15
        self.num_option = option_num
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.path = OPTION_DQN_SAVE_PATH
        self.state_feature_constructor = FeatureConstructor()

        # init higher level DQN network
        self.q_network = DQN_network(self.input_dim, self.output_dim)
        self.target_q_network = DQN_target_network(self.input_dim,
                                                   self.output_dim)
        self.optimizer = torch.optim.Adam(self.q_network.parameters(),
                                          lr=self.learning_rate)
        self.lr_scheduler = StepLR(optimizer=self.optimizer,
                                   step_size=1000,
                                   gamma=0.99)  # lr is roughly 6.6e-6 after 0.5 million steps (1e-3 * 0.99**500)
        self.train_step = 0
        # self.load_network()
        self.q_network.to(self.device)
        self.target_q_network.to(self.device)

        self.decayed_epsilon = self.start_epsilon
        # init option network
        self.record_list = []
        self.global_state_dict = OrderedDict()
        self.time_interval = int(0)
        self.global_state_capacity = 5 * 1440  # we store 5 days' global states to fit replay buffer size.
        self.with_option = isoption
        self.with_charging = ischarging
        self.local_matching = islocal
        self.hex_diffusion = hex_diffusion

        self.h_network_list = []
        self.load_option_networks(self.num_option)
        self.middle_terminal = self.init_terminal_states()

    # def load_network(self, RESUME = False):
    #     if RESUME:
    #         lists = os.listdir(self.path)
    #         lists.sort(key=lambda fn: os.path.getmtime(self.path + "/" + fn))
    #         newest_file = os.path.join(self.path, lists[-1])
    #         path_checkpoint = newest_file
    #         checkpoint = torch.load(path_checkpoint)
    #
    #         self.q_network.load_state_dict(checkpoint['net'])
    #         self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    #
    #         self.train_step = checkpoint['step']
    #         self.copy_parameter()
    #         # self.optimizer.load_state_dict(checkpoint['optimizer'])
    #         print('Successfully load saved network starting from {}!'.format(str(self.train_step)))

    def load_option_networks(self, option_num):
        for option_net_id in range(option_num):
            h_network = OptionNetwork(self.input_dim, 1 + 6 + 5)
            checkpoint = torch.load(
                H_AGENT_SAVE_PATH + 'ht_network_option_%d_1_0_1_11520.pkl' %
                (option_net_id)
            )  # lets try the saved networks after the 14th day.
            h_network.load_state_dict(checkpoint['net'])  # , False
            self.h_network_list.append(h_network.to(self.device))
            print(
                'Successfully load H network {}, total option network num is {}'
                .format(option_net_id, len(self.h_network_list)))

    def init_terminal_states(self):
        """
        Initialize a dict mapping (option_id, hour) to the set of terminal hex ids.
        :return: defaultdict keyed by (option_id, hour) with lists of terminal hex ids
        """
        middle_terminal = defaultdict(list)
        for oid in range(self.num_option):
            with open(TERMINAL_STATE_SAVE_PATH + 'term_states_%d.csv' % oid,
                      'r') as ts:
                next(ts)
                for lines in ts:
                    line = lines.strip().split(',')
                    hr, hid = line  # option_network_id, hour, hex_ids in terminal state
                    middle_terminal[(oid, int(hr))].append(hid)
        return middle_terminal

    def get_actions(self, states, num_valid_relos, assigned_option_ids,
                    global_state):
        """
        Option ids occupy the first three slots of the action space, so an action id < 3
        selects the corresponding h_network.
        :param states: tuples of (tick, hex_id, SOC), with SOC in 0 - 100%
        :param num_valid_relos: only relocation to ADJACENT hexes / charging stations is valid
        :param assigned_option_ids: previously assigned option id per agent, or -1 if none
        :param global_state: shared global state representation
        :return: raw action ids in (0, 14), and converted action ids where option ids have been
                 replaced by the primitive actions chosen by the corresponding option networks
        """
        with torch.no_grad():
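            # Linearly anneal epsilon from start_epsilon down to final_epsilon
            # over epsilon_steps training steps.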
            self.decayed_epsilon = max(
                self.final_epsilon,
                (self.start_epsilon - self.train_step *
                 (self.start_epsilon - self.final_epsilon) /
                 self.epsilon_steps))

            state_reps = np.array([
                self.state_feature_constructor.construct_state_features(state)
                for state in states
            ])
            hex_diffusions = np.array([
                np.tile(self.hex_diffusion[state[1]], (1, 1, 1))
                for state in states
            ])  # state[1] is hex_id

            mask = self.get_action_mask(
                states,
                num_valid_relos)  # mask for unreachable primitive actions

            option_mask = self.get_option_mask(
                states
            )  # if the state is considered terminal for an option, that option is masked out
            # terminate_option_mask = torch.from_numpy(option_mask).to(dtype=torch.bool, device=self.device) # the DQN need a tensor as input, so convert it.

            if True:
                full_action_values = np.random.random(
                    (len(states), self.output_dim
                     ))  # generate a matrix with values from 0 to 1
                for i, state in enumerate(states):
                    if assigned_option_ids[i] != -1:
                        full_action_values[i][assigned_option_ids[
                            i]] = 10  # large enough to keep that option; if it is a terminal state, it is masked to -1 next
                    full_action_values[i][:self.option_dim] = np.negative(
                        option_mask[i, :self.option_dim]
                    )  # convert terminal agents to -1
                    full_action_values[i][(
                        self.option_dim + num_valid_relos[i]):(
                            self.option_dim + self.relocation_dim
                        )] = -1  # mask unreachable neighbors.
                    if state[-1] > HIGH_SOC_THRESHOLD:
                        full_action_values[i][(
                            self.option_dim + self.relocation_dim
                        ):] = -1  # no charging, must relocate
                    elif state[-1] < LOW_SOC_THRESHOLD:
                        full_action_values[i][:(
                            self.option_dim + self.relocation_dim
                        )] = -1  # no relocation, must charge
                action_indexes = np.argmax(full_action_values, 1).tolist()
                # # hard inplace the previously assigned options.
                # action_indexes[np.where(assigned_option_ids!=-1)] = assigned_option_ids[np.where(assigned_option_ids!=-1)]
            # after the DQN picks action ids, convert those that trigger options into primitive action ids
            converted_action_indexes = self.convert_option_to_primitive_action_id(
                action_indexes, state_reps, global_state, hex_diffusions, mask)

        return np.array(action_indexes), \
            np.array(converted_action_indexes) - self.option_dim

    def convert_option_to_primitive_action_id(self, action_indexes, state_reps,
                                              global_state, hex_diffusions,
                                              mask):
        """
        we convert the option ids, e.g., 0,1,2 for each H network, to the generated primitive action ids
        :param action_indexes:
        :param state_reps:
        :param global_state:
        :param hex_diffusions:
        :param mask:
        :return:
        """
        # Work on an array copy so fancy-index assignment works and the caller's
        # raw DQN choices (which keep the option ids) are left untouched.
        action_indexes = np.asarray(action_indexes).copy()
        ids_require_option = defaultdict(list)
        for id, action_id in enumerate(action_indexes):
            if action_id < self.num_option:
                ids_require_option[int(action_id)].append(id)
        for option_id in range(self.num_option):
            if ids_require_option[option_id]:
                full_option_values = self.h_network_list[option_id].forward(
                    torch.from_numpy(
                        state_reps[ids_require_option[option_id]]).to(
                            dtype=torch.float32, device=self.device),
                    torch.from_numpy(
                        np.concatenate([
                            np.tile(
                                global_state,
                                (len(ids_require_option[option_id]), 1, 1, 1)),
                            hex_diffusions[ids_require_option[option_id]]
                        ],
                                       axis=1)).to(dtype=torch.float32,
                                                   device=self.device))
                # here mask is of batch x 15 dimension, we omit the first 3 columns, which should be options.
                primitive_action_mask = mask[
                    ids_require_option[option_id], self.
                    option_dim:]  # only primitive actions in option generator
                full_option_values[primitive_action_mask] = -9e10
                option_generated_primitive_action_ids = torch.argmax(
                    full_option_values,
                    dim=1).cpu().numpy()  # let the option network select a primitive action
                # overwrite the option id with the generated primitive action id,
                # shifted into the primitive action slots (3..14)
                action_indexes[ids_require_option[option_id]] = (
                    option_generated_primitive_action_ids + self.option_dim)
        return action_indexes

    def add_global_state_dict(self, global_state_list):
        for tick in global_state_list.keys():
            if tick not in self.global_state_dict.keys():
                self.global_state_dict[tick] = global_state_list[tick]
        if len(self.global_state_dict.keys(
        )) > self.global_state_capacity:  #capacity limit for global states
            for _ in range(
                    len(self.global_state_dict.keys()) -
                    self.global_state_capacity):
                self.global_state_dict.popitem(last=False)

    def add_transition(self, state, action, next_state, reward, terminate_flag,
                       time_steps, valid_action):
        self.memory.push(state, action, next_state, reward, terminate_flag,
                         time_steps, valid_action)

    def batch_sample(self):
        samples = self.memory.sample(
            self.batch_size)  # random.sample(self.memory, self.batch_size)
        return samples
        # state, action, next_state, reward = zip(*samples)
        # return state, action, next_state, reward

    def get_main_Q(self, local_state, global_state):
        return self.q_network.forward(local_state, global_state)

    def get_target_Q(self, local_state, global_state):
        return self.target_q_network.forward(local_state, global_state)

    def copy_parameter(self):
        self.target_q_network.load_state_dict(self.q_network.state_dict())

    def soft_target_update(self, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(
                self.target_q_network.parameters(),
                self.q_network.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def train(self, record_hist):
        self.train_step += 1
        if len(self.memory) < self.batch_size:
            print('batches in replay buffer is {}'.format(len(self.memory)))
            return

        transitions = self.batch_sample()
        batch = self.memory.Transition(*zip(*transitions))

        global_state_reps = [
            self.global_state_dict[int(state[0] / 60)] for state in batch.state
        ]  # should be list of np.array

        global_next_state_reps = [
            self.global_state_dict[int(state_[0] / 60)]
            for state_ in batch.next_state
        ]  # should be list of np.array

        state_reps = [
            self.state_feature_constructor.construct_state_features(state)
            for state in batch.state
        ]
        next_state_reps = [
            self.state_feature_constructor.construct_state_features(state_)
            for state_ in batch.next_state
        ]

        hex_diffusion = [
            np.tile(self.hex_diffusion[state[1]], (1, 1, 1))
            for state in batch.state
        ]
        hex_diffusion_ = [
            np.tile(self.hex_diffusion[state_[1]], (1, 1, 1))
            for state_ in batch.next_state
        ]

        state_batch = torch.from_numpy(np.array(state_reps)).to(
            dtype=torch.float32, device=self.device)
        action_batch = torch.from_numpy(np.array(
            batch.action)).unsqueeze(1).to(dtype=torch.int64,
                                           device=self.device)
        reward_batch = torch.from_numpy(np.array(
            batch.reward)).unsqueeze(1).to(dtype=torch.float32,
                                           device=self.device)
        time_step_batch = torch.from_numpy(np.array(
            batch.time_steps)).unsqueeze(1).to(dtype=torch.float32,
                                               device=self.device)

        next_state_batch = torch.from_numpy(np.array(next_state_reps)).to(
            device=self.device, dtype=torch.float32)
        global_state_batch = torch.from_numpy(
            np.concatenate(
                [np.array(global_state_reps),
                 np.array(hex_diffusion)], axis=1)).to(dtype=torch.float32,
                                                       device=self.device)
        global_next_state_batch = torch.from_numpy(
            np.concatenate(
                [np.array(global_next_state_reps),
                 np.array(hex_diffusion_)],
                axis=1)).to(dtype=torch.float32, device=self.device)

        q_state_action = self.get_main_Q(state_batch,
                                         global_state_batch).gather(
                                             1, action_batch.long())
        # add a mask
        all_q_ = self.get_target_Q(next_state_batch, global_next_state_batch)
        option_mask = torch.from_numpy(self.get_option_mask(
            batch.next_state)).to(dtype=torch.bool, device=self.device)
        mask_ = self.get_action_mask(
            batch.next_state,
            batch.valid_action_num)  # action mask for next state
        all_q_[option_mask] = -9e10
        all_q_[mask_] = -9e10
        maxq = all_q_.max(1)[0].detach().unsqueeze(1)
        y = reward_batch + maxq * torch.pow(self.gamma, time_step_batch)
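        # Target: y = r + gamma**k * max_a' Q_target(s', a'), where k is the number of
        # elapsed time steps and masked (invalid or terminal) actions are excluded from the max.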
        loss = F.smooth_l1_loss(q_state_action, y)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_network.parameters(),
                                 self.clipping_value)
        self.optimizer.step()
        self.lr_scheduler.step()

        self.record_list.append([
            self.train_step,
            round(float(loss), 3),
            round(float(reward_batch.view(-1).mean()), 3)
        ])
        self.save_parameter(record_hist)
        print(
            'Training step is {}; Learning rate is {}; Epsilon is {}:'.format(
                self.train_step, self.lr_scheduler.get_last_lr(),
                round(self.decayed_epsilon, 4)))

    def get_action_mask(self, batch_state, batch_valid_action):
        """
        The action space: the first 3 slots are h_network options, then 7 relocation actions,
        and 5 nearest charging stations.
        :param batch_state: batch of states
        :param batch_valid_action: number of reachable neighboring hexes per state
        :return: boolean tensor mask over invalid actions
        """
        mask = np.zeros((len(batch_state), self.output_dim))  # (num_state, 15)
        for i, state in enumerate(batch_state):
            mask[i][(self.option_dim + batch_valid_action[i]):(
                self.option_dim + self.relocation_dim
            )] = 1  # limited to relocate to reachable neighboring hexes
            if state[-1] > HIGH_SOC_THRESHOLD:
                mask[i][(
                    self.option_dim +
                    self.relocation_dim):] = 1  # no charging, must relocate
            elif state[-1] < LOW_SOC_THRESHOLD:
                mask[i][:(
                    self.option_dim +
                    self.relocation_dim)] = 1  # no relocation, must charge

        mask = torch.from_numpy(mask).to(dtype=torch.bool, device=self.device)
        return mask

    def get_option_mask(self, states):
        """
        Use is_terminal to flag, per option, states that are terminal given their hour and hex_id.
        :param states: batch of states
        :return: array of shape (len(states), output_dim) with 1s marking options to mask out
        """
        terminate_option_mask = np.zeros((len(states), self.output_dim))
        for oid in range(self.num_option):
            terminate_option_mask[:, oid] = self.is_terminal(
                states, oid)  # set as 0 if not in terminal set
        for oid in range(self.num_option, self.option_dim):
            terminate_option_mask[:, oid] = 1  # mask out empty options
        return terminate_option_mask

    def is_terminal(self, states, oid):
        """

        :param states: batch of states (tick, hex_id, SOC)
        :param oid: option network id
        :return: list of 0/1 flags, 1 if the state's hex is terminal for option oid at that hour
        """
        # state[0] is the tick in seconds, state[1] the hex id; middle_terminal stores hex ids as strings
        return [
            1 if str(state[1]) in self.middle_terminal[(
                oid, int(state[0] // (60 * 60) % 24))] else 0
            for state in states
        ]

    def is_initial(self, states, oid):
        """

        :param states: batch of states (tick, hex_id, SOC)
        :param oid: option network id
        :return: list of 0/1 flags, 1 if the state's hex is NOT terminal for option oid at that hour
        """
        return [
            1 if str(state[1]) not in self.middle_terminal[(
                oid, int(state[0] // (60 * 60) % 24))] else 0
            for state in states
        ]

    def save_parameter(self, record_hist):
        # torch.save(self.q_network.state_dict(), self.dqn_path)
        if self.train_step % SAVING_CYCLE == 0:
            checkpoint = {
                "net": self.q_network.state_dict(),
                # 'optimizer': self.optimizer.state_dict(),
                "step": self.train_step,
                "lr_scheduler": self.lr_scheduler.state_dict()
            }
            if not os.path.isdir(self.path):
                os.mkdir(self.path)
            # print('the path is {}'.format('logs/dqn_model/duel_dqn_%s.pkl'%(str(self.train_step))))
            torch.save(
                checkpoint,
                'logs/test/cnn_dqn_model/dqn_with_option_%d_%d_%d_%d_%s.pkl' %
                (self.num_option, bool(self.with_option),
                 bool(self.with_charging), bool(
                     self.local_matching), str(self.train_step)))
            # record training process (stacked before)
            for item in self.record_list:
                record_hist.writelines('{},{},{}\n'.format(
                    item[0], item[1], item[2]))
            print(
                'Training step: {}, replay buffer size:{}, epsilon: {}, learning rate: {}'
                .format(self.record_list[-1][0], len(self.memory),
                        self.decayed_epsilon, self.lr_scheduler.get_last_lr()))
            self.record_list = []
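
A compact editorial sketch of the action-masking idea this agent relies on in get_actions and train: invalid actions (unreachable hexes, disallowed relocation/charging, terminated options) are pushed to a very negative sentinel before the argmax/max, so they can never be selected. The array shapes and names below are hypothetical.

import numpy as np

q_values = np.random.randn(4, 15)         # 4 states x 15 actions (3 options + 7 relocations + 5 chargers)
invalid = np.zeros_like(q_values, dtype=bool)
invalid[:, 3 + 5:3 + 7] = True            # e.g. two unreachable relocation slots

masked_q = q_values.copy()
masked_q[invalid] = -9e10                 # same sentinel value the agent uses
greedy_actions = masked_q.argmax(axis=1)  # masked actions can never be chosen
print(greedy_actions)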
Example #24
def main():

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = import_module(f'model.{args.arch}').__dict__[args.teacher_model]().to(device)

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir, map_location=device)
    

    if args.arch == 'densenet':
        state_dict_t = {}
        for k, v in ckpt_t['state_dict'].items():
            new_key = '.'.join(k.split('.')[1:])
            if new_key == 'linear.weight':
                new_key = 'fc.weight'
            elif new_key == 'linear.bias':
                new_key = 'fc.bias'
            state_dict_t[new_key] = v
    else:
        state_dict_t = ckpt_t['state_dict']


    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = import_module(f'model.{args.arch}').__dict__[args.student_model]().to(device)

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(device) 

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    param_s = [param for name, param in model_s.named_parameters() if 'mask' not in name]
    param_m = [param for name, param in model_s.named_parameters() if 'mask' in name]

    optimizer_s = optim.SGD(param_s, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume, map_location=device)
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']

        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))


    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]
    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print_logger.info(f"Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt', map_location=device)

    model = import_module('utils.preprocess').__dict__[f'{args.arch}'](args, best_model['state_dict_s'])
Example #25
class BDDVAgent(
        LearningAgent):  # ++ Extend Learning agent

    def __init__(self, cfg):
        super(BDDVAgent, self).__init__(cfg)

        use_cuda = self._use_cuda  # ++ Parent class already saves some configuration variables
        # ++ All parent variables should start with _.

        # -- Get necessary variables from cfg
        self.cfg = cfg

        # -- Initialize model
        model_class = get_models(cfg.model)

        input_shape = cfg.data_info.image_shape
        input_shape[0] *= cfg.data_info.frame_seq_len
        self.model = model_class[0](cfg, input_shape, cfg.model.nr_bins)
        # ++ All models receive as parameters (configuration namespace, input data size,
        # ++ output data size)

        self._models.append(
            self.model)  # -- Add models & optimizers to base for saving

        # ++ After adding model you can set the agent to cuda mode
        # ++ Parent class already makes some adjustments. E.g. turns model to cuda mode
        if use_cuda:
            self.cuda()

        self._bins = np.arange(-1.0, 1.0, 2.0 / cfg.model.nr_bins)
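        # Steering in [-1, 1) is discretized into nr_bins equal-width bins; the model
        # predicts a distribution over bins that is later mapped back to a bin center.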
        # -- Initialize optimizers
        self.optimizer = self.get_optim(cfg.train.algorithm,
                                        cfg.train.algorithm_args, self.model)
        self.scheduler = StepLR(self.optimizer, cfg.train.step_size,
                                cfg.train.decay)
        self._optimizers.append(
            self.optimizer)  # -- Add models & optimizers to base for saving
        # -- Change settings from parent class
        # ++ Parent class automatically initializes 4 metrics: loss/acc for train/test
        # ++ E.g switch metric slope
        self.set_eval_metric_comparison(True)

        # ++ E.g. to add variable name to be saved at checkpoints
        self._save_data.append("scheduler")

        self._tensorboard_model = False
        self.loss_values_train = []
        self.loss_values_test = []

        ##### Make directories for demo output ########

        self.img_dir = os.getcwd() + "/" + image_dir
        self.act_dir = os.getcwd() + "/" + activations_dir
        self.steer_dir = os.getcwd() + "/" + steer_distr_dir

        if not os.path.isdir(self.img_dir):
            os.mkdir(self.img_dir)
        if not os.path.isdir(self.act_dir):
            os.mkdir(self.act_dir)
        if not os.path.isdir(self.steer_dir):
            os.mkdir(self.steer_dir)
        self.nr_img = 0

        ################################################

        super(BDDVAgent, self).__end_init__()

    def _session_init(self):
        if self._is_train:
            self.optimizer.zero_grad()

    def _train(self, data_loader):
        """
        Considering a dataloader (loaded from config.)
        Implement the training loop.
        :return training loss metric & other information
        """
        optimizer = self.optimizer
        scheduler = self.scheduler
        use_cuda = self._use_cuda
        model = self.model
        criterion = self._get_criterion
        branches = self.model.get_branches(use_cuda)
        train_loss = 0

        progress_bar = ProgressBar(
            'Loss: %(loss).3f', dict(loss=0), len(data_loader))

        for batch_idx, (images, speed, steer_distr, mask) in enumerate(data_loader):
            optimizer.zero_grad()
            images = to_cuda(images, use_cuda)
            speed_target = to_cuda(speed, use_cuda)
            steer_distr = to_cuda(steer_distr, use_cuda)
            inter_output, speed_output, _ = model(images, speed_target)

            output = to_cuda(torch.zeros((mask.shape[0], self.cfg.model.nr_bins)), use_cuda)

            # Reshape mask to use it for selecting frames at each moment
            mask = mask.reshape((-1, mask.shape[0]))

            for i in range(0, len(branches)):
                # Hardcode for non-temporal case for now
                filter_ = (mask[0] == i)
                if filter_.any():
                    output[filter_] = branches[i](inter_output[filter_])

            loss = criterion(output, speed_output, speed_target, steer_distr)

            loss.backward()
            train_loss += loss.item()

            optimizer.step()
            scheduler.step()
            progress_bar.update(
                batch_idx, dict(loss=(train_loss / (batch_idx + 1))))

            self.loss_values_train.append(loss.item())

            ################### TensorBoard logging (disabled) ################################

            #loss function

            #self._writer.add_scalar(
            #    "loss_function", loss.item(),
            #    batch_idx + self._train_epoch * len(data_loader))

            #model
            #if self._tensorboard_model is False:
            #    self._tensorboard_model = True
            #    self._writer.add_graph(model, (images, speed_target))

            #####################################################################

        progress_bar.finish()

        return train_loss, {}

    def _get_criterion(self, branch_outputs, speed_outputs,
                       speed_target, steer_distr):
        loss1_steer = torch.nn.functional.mse_loss(
            branch_outputs, steer_distr, reduction='sum')

        loss1 = loss1_steer

        loss2 = (speed_outputs - speed_target) * (speed_outputs - speed_target)
        loss2 = loss2.sum()# / branch_outputs.shape[0]

        loss = (0.95 * loss1 + 0.05 * loss2) / branch_outputs.shape[0]
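        # Combined objective: 95% weight on the steering-distribution loss, 5% on the
        # speed loss, normalized by the batch size.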
        return loss

    def _test(self, data_loader):
        """
        Considering a dataloader (loaded from config.)
        Implement the testing loop.
        """
        use_cuda = self._use_cuda
        model = self.model
        criterion = self._get_criterion
        branches = self.model.get_branches(use_cuda)
        test_loss = 0

        progress_bar = ProgressBar(
            'Loss: %(loss).3f', dict(loss=0), len(data_loader))

        for batch_idx, (images, speed, steer_distr, mask) in enumerate(data_loader):
            images = to_cuda(images, use_cuda)
            speed_target = to_cuda(speed, use_cuda)
            steer_distr = to_cuda(steer_distr, use_cuda)

            inter_output, speed_output, _ = model(images, speed_target)

            output = to_cuda(torch.zeros((mask.shape[0], self.cfg.model.nr_bins)), use_cuda)

            # Reshape mask to use it for selecting frames at each moment
            mask = mask.reshape((-1, mask.shape[0]))

            for i in range(0, len(branches)):
                # Hardcode for non-temporal case for now
                filter_ = (mask[0] == i)
                if filter_.any():
                    output[filter_] = branches[i](inter_output[filter_])

            loss = criterion(output, speed_output, speed_target, steer_distr)

            test_loss += loss.item()

            self.loss_values_test.append(loss.item())

            progress_bar.update(
                batch_idx, dict(loss=(test_loss / (batch_idx + 1))))

        progress_bar.finish()

        return test_loss, None, {}

    def _get_steer_from_bins(self, steer_vector):
        # Pass the steer values through softmax_layer and get the bin index
        bin_index = torch.nn.functional.softmax(steer_vector, dim=1).argmax()
        #bin_index = steer_vector.argmax()
        plt.plot(self._bins + 1.0 / len(self._bins),
                 torch.nn.functional.softmax(steer_vector, dim=1).data[0].numpy())
        plt.show(block=False)

        plt.draw()
        plt.pause(0.0001)
        #plt.savefig(self.steer_dir + "/distr_" + str(self.nr_img) + ".png")
        plt.gcf().clear()
        # Convert the bin index back to a steer value: bin left edge plus half a bin width
        return self._bins[bin_index] + 1.0 / len(self._bins)

    def _show_activation_image(self, raw_activation, image_activation):
        activation_map = raw_activation.data[0, 0].cpu().numpy()
        activation_map = (activation_map - np.min(activation_map)) / (
            np.max(activation_map) - np.min(activation_map))

        activation_map = (activation_map * 255.0)

        if image_activation.shape[0] != activation_map.shape[0]:
            activation_map = scipy.misc.imresize(
                activation_map,
                [image_activation.shape[0], image_activation.shape[1]])

        image_activation[:, :, 1] += activation_map.astype(np.uint8)
        activation_map = cv2.applyColorMap(
            activation_map.astype(np.uint8), cv2.COLORMAP_JET)

        image_activation = cv2.resize(image_activation, (720, 460),
                                      cv2.INTER_AREA)
        image_activation = cv2.cvtColor(image_activation, cv2.COLOR_RGB2BGR)

        activation_map = cv2.resize(activation_map, (720, 460), cv2.INTER_AREA)

        cv2.imshow("activation",
                   np.concatenate((image_activation, activation_map), axis=1))

        if cv2.waitKey(1) & 0xFF == ord('q'):
            return

    def run_image(self, image_raw, speed, cmd):
        self.set_eval_mode()

        image = np.transpose(image_raw, (2, 0, 1)).astype(np.float32)
        image = np.multiply(image, 1.0 / 127.5) - 1
        image = to_cuda(torch.from_numpy(image), self._use_cuda)
        image = image.unsqueeze(0)
        speed = to_cuda(torch.Tensor([speed / 90.0]), self._use_cuda)
        speed = speed.unsqueeze(0)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, speed_output, activation_map = self.model(image, speed)

        output = branches[cmd](inter_output)

        steer_angle = self._get_steer_from_bins(output)
        speed_output = speed_output.data.cpu()[0].numpy()
        return steer_angle, speed_output[0] * 90, activation_map

    def run_1step(self, image_raw, speed, cmd):
        image = np.transpose(image_raw, (2, 0, 1)).astype(np.float32)
        image = np.multiply(image, 1.0 / 127.5) - 1
        image = to_cuda(torch.from_numpy(image), self._use_cuda)
        image = image.unsqueeze(0)
        speed = to_cuda(torch.Tensor([speed / 90.0]), self._use_cuda)
        speed = speed.unsqueeze(0)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, speed_output, activation_map = self.model(image, speed)

        if self.cfg.activations:
            self._show_activation_image(activation_map, np.copy(image_raw))

        output = branches[cmd](inter_output)

        steer_angle = self._get_steer_from_bins(output)
        speed_output = speed_output.data.cpu()[0].numpy()
        return steer_angle, speed_output[0] * 90

    def _eval_episode(self, file_name):
        video_file = file_name[0]
        info_file = file_name[1]
        info = pd.read_csv(info_file)

        nr_images = len(info)
        previous_speed = info['linear_speed'][0]

        general_mse = steer_mse = 0

        # Determine steering angles and commands
        helper = DatasetHelper(None, None, None, self.cfg.dataset)
        frame_indices = range(len(info))
        course = info['course']
        linear_speed = info['linear_speed']
        angles, cmds = helper.get_steer(frame_indices, course, linear_speed)

        # Open video to read frames
        vid = cv2.VideoCapture(video_file)

        for index in range(nr_images):
            ret, frame = vid.read()
            if not ret:
                print('Could not retrieve frame')
                return None, None

            gt_speed = linear_speed[index]
            gt_steer = angles[index]

            predicted_steer, predicted_speed = self.run_1step(
                frame, previous_speed, cmds[index])

            steer = (predicted_steer - gt_steer) * (predicted_steer - gt_steer)
            speed = (predicted_speed - gt_speed) * (predicted_speed - gt_speed)
            steer_mse += steer

            general_mse += 0.05 * speed + 0.95 * steer

            log.info("Frame number {}".format(index))
            log.info("Steer: predicted {}, ground_truth {}".format(
                predicted_steer, gt_steer))

            log.info("Speed: predicted {}, ground_truth {}".format(
                predicted_speed, gt_speed))

            previous_speed = gt_speed

        vid.release()

        general_mse /= float(nr_images)
        steer_mse /= float(nr_images)

        return general_mse, steer_mse

    def eval_agent(self):
        self.set_eval_mode()

        f = open(self._save_path + "/eval_results.txt", "wt")
        data_files = sorted(os.listdir(self.cfg.dataset.dataset_test_path))
        video_files = []
        for file in data_files:
            info_file = file.split('.')[0] + '.csv'
            video_files.append((os.path.join(self.cfg.dataset.dataset_test_path, file),
                os.path.join(self.cfg.dataset.info_test_path, info_file)))
        eval_results = []

        mean_mse = mean_steer = 0
        for video_file in video_files:
            general_mse, steer_mse = self._eval_episode(video_file)
            eval_results.append((general_mse, steer_mse))
            mean_mse += general_mse
            mean_steer += steer_mse

            f.write(
                "****************Evaluated {} *******************\n".format(
                    video_file))
            f.write("Mean squared error is {}\n".format(str(general_mse)))
            f.write("Mean squared error for steering is {}\n".format(
                str(steer_mse)))
            f.write("************************************************\n\n")
            f.flush()

        mean_mse /= float(len(video_files))
        mean_steer /= float(len(video_files))

        std_mse = std_steer = 0
        for i in range(len(video_files)):
            std_mse += (eval_results[i][0] - mean_mse) * (
                eval_results[i][0] - mean_mse)
            std_steer += (eval_results[i][1] - mean_steer) * (
                eval_results[i][1] - mean_steer)

        std_mse /= float(len(video_files))
        std_steer /= float(len(video_files))

        std_mse = math.sqrt(std_mse)
        std_steer = math.sqrt(std_steer)

        f.write("****************Final Evaluation *******************\n")
        f.write("Mean squared error is {} with standard deviation {}\n".format(
            str(mean_mse), str(std_mse)))
        f.write(
            "Mean squared error for steering is {} with standard deviation {}\n".
            format(str(mean_steer), str(std_steer)))
        f.write("******************************************************")
        f.flush()
        f.close()

    def _control_function(self, image_input_raw, real_speed, control_input):
        """
        Implement for carla simulator run.
        :return: steer, acc, brake
        """
        print("Control input is {}".format(control_input))

        image_input = scipy.misc.imresize(image_input_raw, [
            self.cfg.data_info.image_shape[1],
            self.cfg.data_info.image_shape[2]
        ])
        image_input = np.transpose(image_input, (2, 0, 1)).astype(np.float32)
        image_input = np.multiply(image_input, 1.0 / 127.5) - 1.0
        image_input = torch.from_numpy(image_input)
        image_input = image_input.unsqueeze(0)
        speed = torch.Tensor([real_speed / 25.0])
        speed = speed.unsqueeze(0)

        branches = self.model.get_branches(self._use_cuda)

        inter_output, predicted_speed, activation_map = self.model(
            image_input, speed)

        if self.cfg.activations:
            self._show_activation_image(activation_map,
                                        np.copy(image_input_raw))

        if control_input == 2 or control_input == 0:
            output = branches[1](inter_output)
        elif control_input == 3:
            output = branches[2](inter_output)
        elif control_input == 4:
            output = branches[3](inter_output)
        else:
            output = branches[4](inter_output)

        steer = self._get_steer_from_bins(output[:, :-2])
        output = output.data.cpu()[0].numpy()
        acc, brake = output[-2], output[-1]


        predicted_speed = predicted_speed.data[0].numpy()
        real_predicted = predicted_speed * 25.0
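
        # Anti-stall heuristic: if the car has nearly stopped but the network predicts
        # it should be moving, add proportional throttle and release the brake.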

        if real_speed < 2.0 and real_predicted > 3.0:
            acc = 1 * (5.6 / 25.0 - real_speed / 25.0) + acc
            brake = 0.0

        self.nr_img += 1
        return steer, acc, brake

    def _set_eval_mode(self):
        """
        Custom configuration when changing to evaluation mode
        """
        if self.cfg.activations:
            self.model.set_forward('forward_deconv')
        else:
            self.model.set_forward('forward_simple')
        if self._use_cuda:
            self.cuda()

    def _set_train_mode(self):
        """
        Custom configuration when changing to train mode
        """
        self.model.set_forward('forward_simple')
        if self._use_cuda:
            self.cuda()

    def _save(self, save_data, path):
        """
        Called when saving agent state. Agent already saves variables defined in the list
        self._save_data and other default options.
        :param save_data: Pre-loaded dictionary with saved data. Append here other data
        :param path: Path to folder where other custom data can be saved
        :return: should return default save_data dictionary to be saved
        """
        save_data['scheduler_state'] = self.scheduler.state_dict()
        save_data['train_epoch'] = self._train_epoch
        save_data['loss_value_train'] = self.loss_values_train
        save_data['loss_value_test'] = self.loss_values_test

        return save_data

    def _resume(self, agent_check_point_path, saved_data):
        """
        Custom resume scripts should be implemented here
        :param agent_check_point_path: Path of the checkpoint resumed
        :param saved_data: loaded checkpoint data (dictionary of variables)
        """
        self.scheduler.load_state_dict(saved_data['scheduler_state'])
        self.scheduler.optimizer = self.optimizer
        self.model = self._models[0]
        self.optimizer = self._optimizers[0]
        self._train_epoch = saved_data['train_epoch']
        self.loss_values_train = saved_data['loss_value_train']
        self.loss_values_test = saved_data['loss_value_test']
        if not self._use_cuda:
            self.model.cpu()
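
# Illustrative aside (not part of the original example): the _save/_resume
# hooks above amount to persisting the StepLR state together with the model
# and optimizer. A minimal stand-alone sketch of that pattern, with an
# arbitrary checkpoint path:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# save model, optimizer and scheduler state in one checkpoint
torch.save({'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict()}, 'checkpoint.pth')

# resume: restore all three before continuing training
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
scheduler.load_state_dict(ckpt['scheduler_state'])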
Example #26
0
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train,
                                           args.num_instances)
    train_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=train_transform),
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     sampler=sampler,
                                     pin_memory=True,
                                     drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    cluster_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_target_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    n_s_classes = source_dataset.num_train_pids
    args.n_classes = n_s_classes + len(target_dataset.train)
    args.n_s_classes = n_s_classes
    args.n_t_classes = len(target_dataset.train)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         args.n_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer)
    features_dim = model.features_dim

    idm_bn_names = filter_layers(args.stage)
    convert_dsbn_idm(model, idm_bn_names, idm=False)

    model = model.to(device)
    model = DataParallel(model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # create XBM
    dataset_size = len(source_dataset.train) + len(target_dataset.train)
    memory_size = int(args.ratio * dataset_size)
    xbm = XBM(memory_size, features_dim)

    # initialize source-domain class centroids
    source_feature_dict = extract_reid_feature(cluster_source_loader,
                                               model,
                                               device,
                                               normalize=True)
    source_features_per_id = {}
    for f, pid, _ in source_dataset.train:
        if pid not in source_features_per_id:
            source_features_per_id[pid] = []
        source_features_per_id[pid].append(source_feature_dict[f].unsqueeze(0))
    source_centers = [
        torch.cat(source_features_per_id[pid], 0).mean(0)
        for pid in sorted(source_features_per_id.keys())
    ]
    source_centers = torch.stack(source_centers, 0)
    source_centers = F.normalize(source_centers, dim=1)
    model.module.head.weight.data[0:n_s_classes].copy_(
        source_centers.to(device))

    # save memory
    del source_centers, cluster_source_loader, source_features_per_id

    # define optimizer and lr scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                 rate=args.rate),
                     args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.1)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        train_target_iter = run_dbscan(cluster_target_loader, model,
                                       target_dataset, train_transform, args)

        # train for one epoch
        print(lr_scheduler.get_lr())  # newer PyTorch versions prefer get_last_lr() here
        train(train_source_iter, train_target_iter, model, optimizer, xbm,
              epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

        # update lr
        lr_scheduler.step()

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
Example #27
0
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    #loader = import_module('data.' + args.dataset).Data(args)

    # while(1):
    #     a=1

    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_train')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_val')
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=batch_sizes, shuffle=True,
    #     num_workers=8, pin_memory=True, sampler=None)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_val_rec')


    train_queue = getTrainValDataset(traindir, valdir, batch_size=batch_size, val_batch_size=batch_size,
                                     num_shards=num_gpu, workers=num_workers)
    valid_queue = getTestDataset(valdir, test_batch_size=batch_size, num_shards=num_gpu,
                                 workers=num_workers)

    #loader = cifar100(args)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))


    if args.pruned:
        mask = checkpoint['mask']
        model = resnet_56_sparse(has_mask = mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    # model = torchvision.models.resnet18()

    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
        print('Flops:  ' + flops)
        print('Params: ' + params)
    pruned_dir = args.pruned_dir
    checkpoint_pruned = torch.load(pruned_dir, map_location=torch.device(f"cuda:{args.gpus[0]}"))
    model = torch.nn.DataParallel(model)
    #
    # new_state_dict_pruned = OrderedDict()
    # for k, v in checkpoint_pruned.items():
    #     name = k[7:]
    #     new_state_dict_pruned[name] = v
    # model.load_state_dict(new_state_dict_pruned)

    model.load_state_dict(checkpoint_pruned['state_dict_s'])

    test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test)
    logging.info('Simply test after prune: %e ', test_prec1)
    logging.info('Model size: %e ', get_parameters_size(model)/1e6)

    exit()  # NOTE: with this early exit in place, everything below is unreachable

    if args.test_only:
        return
    param_s = [param for name, param in model.named_parameters() if 'mask' not in name]
    #optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.00001, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer = optim.SGD(param_s, lr=1e-5, momentum=args.momentum, weight_decay=args.weight_decay)
    # Note: this StepLR is immediately replaced by the CosineAnnealingLR on the
    # next line, so only the cosine schedule is actually used below.
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.num_epochs))


    model_kd = None
    if kd_flag:
        model_kd = ResNet101()
        ckpt_kd = torch.load('resnet101.t7', map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_kd = ckpt_kd['net']
        new_state_dict_kd = OrderedDict()
        for k, v in state_dict_kd.items():
            name = k[7:]
            new_state_dict_kd[name] = v
    #print(new_state_dict_kd)
        model_kd.load_state_dict(new_state_dict_kd)
        model_kd = model_kd.to(args.gpus[1])

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))
    #print(model.named_parameters())
    #for name, param in model.named_parameters():
        #print(name)
    for epoch in range(start_epoch, 60):
        scheduler.step()  # scheduler.step(epoch)
        t1 = time.time()
        train(args, train_queue, model, criterion, optimizer, writer_train, epoch, model_kd)
        test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test, epoch)
        t2 = time.time()
        print(epoch, t2 - t1)
        logging.info('TEST Top1: %e Top5: %e ', test_prec1, test_prec5)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
        logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)
        train_queue.reset()
        valid_queue.reset()

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
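
# Illustrative aside (not from the original example): the loop above calls
# scheduler.step() at the top of each epoch, the pre-PyTorch-1.1 convention.
# Since PyTorch 1.1 the recommended ordering is optimizer.step() first and
# scheduler.step() once per epoch afterwards, otherwise the first value of
# the schedule is skipped. A minimal self-contained sketch of that ordering:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

toy_model = nn.Linear(4, 2)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1)
toy_scheduler = StepLR(toy_optimizer, step_size=30, gamma=0.1)
for _ in range(3):
    toy_optimizer.zero_grad()
    toy_model(torch.randn(8, 4)).sum().backward()
    toy_optimizer.step()   # update the weights first
    toy_scheduler.step()   # then advance the learning-rate schedule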
Example #28
0
def main(args):
    #
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    set_random_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    if args.dataset == 'mnist':
        train_data = get_dataset('mnist-train',  args.dataroot)
        test_data = get_dataset('mnist-test',  args.dataroot)
        train_tr = test_tr = get_transform('mnist_normalize')

    if args.dataset == 'cifar10':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar10-train',  args.dataroot)
        test_data = get_dataset('cifar10-test',  args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')
        
    if args.dataset == 'cifar-fs-train':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar-fs-train-train',  args.dataroot)
        test_data = get_dataset('cifar-fs-train-test',  args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')

    if args.dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args.dataroot)
        test_data = get_dataset('miniimagenet-train-test', args.dataroot)
        train_tr = get_transform('cifar_augment_normalize_84' if args.data_augmentation else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')
    

    model = ResNetClassifier(train_data['n_classes'], train_data['im_size']).to(device)
    if args.ckpt_path != '':
        loaded = torch.load(args.ckpt_path)
        model.load_state_dict(loaded)
        ipdb.set_trace()
    if args.eval:
        acc = test(args, model, device, test_data, test_tr, args.n_eval_batches)
        print("Eval Acc: ", acc)
        sys.exit()

    # Trace logging
    mkdir(args.output_dir)
    eval_fieldnames = ['global_iteration','val_acc','train_acc']
    eval_logger = CSVLogger(every=1,
                                 fieldnames=eval_fieldnames,
                                 resume=args.resume,
                                 filename=os.path.join(args.output_dir, 'eval_log.csv'))
    wandb.run.name = os.path.basename(args.output_dir)
    wandb.run.save()
    wandb.watch(model)

    if args.optim == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
    if args.dataset == 'mnist':
        scheduler = StepLR(optimizer, step_size=1, gamma=.7)
    else:
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    start_epoch = 1
    if args.resume:
        last_ckpt_path = os.path.join(args.output_dir, 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            loaded = torch.load(last_ckpt_path)
            model.load_state_dict(loaded['model_sd'])
            optimizer.load_state_dict(loaded['optimizer_sd'])
            scheduler.load_state_dict(loaded['scheduler_sd'])
            start_epoch = loaded['epoch']

    # It's important to set seed again before training b/c dataloading code
    # might have reset the seed.
    set_random_seed(args.seed)
    best_val = 0
    if args.db: 
        scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4], gamma=0.1)
        args.epochs = 5
    for epoch in range(start_epoch, args.epochs + 1):
        if epoch % args.ckpt_every == 0:
            torch.save(model.state_dict(), os.path.join(args.output_dir , f"ckpt_{epoch}.pt"))

        stats_dict = {'global_iteration':epoch}
        val = stats_dict['val_acc'] = test(args, model, device, test_data, test_tr, args.n_eval_batches)
        stats_dict['train_acc'] = test(args, model, device, train_data, test_tr, args.n_eval_batches)
        grid = make_grid(torch.stack([train_tr(x) for x in train_data['x'][:30]]), nrow=6).permute(1,2,0).numpy()
        img_dict = {"examples": [wandb.Image(grid, caption="Data batch")]}
        wandb.log(stats_dict)
        wandb.log(img_dict)
        eval_logger.writerow(stats_dict)
        plot_csv(eval_logger.filename, os.path.join(args.output_dir, 'iteration_plots.png'))

        train(args, model, device, train_data, train_tr, optimizer, epoch)
        
        scheduler.step(epoch)  # passing the epoch to step() is deprecated since PyTorch 1.1

        if val > best_val: 
            best_val = val
            torch.save(model.state_dict(), os.path.join(args.output_dir , f"ckpt_best.pt"))

        # For `resume`
        model.cpu()
        torch.save({
            'model_sd': model.state_dict(),
            'optimizer_sd': optimizer.state_dict(), 
            'scheduler_sd': scheduler.state_dict(), 
            'epoch': epoch + 1
            }, os.path.join(args.output_dir, "last_ckpt.pt"))
        model.to(device)
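
# Illustrative aside (not from the original example): StepLR(step_size=1,
# gamma=0.7) multiplies the learning rate by 0.7 every epoch, so after k
# epochs lr_k = lr_0 * 0.7 ** k, while the MultiStepLR branch scales the rate
# by 0.2 only at the milestone epochs 60, 120 and 160. A quick check:
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

toy_opt = optim.SGD(nn.Linear(2, 2).parameters(), lr=1.0)
toy_sched = StepLR(toy_opt, step_size=1, gamma=0.7)
for k in range(3):
    toy_opt.step()                      # dummy update, keeps the step ordering valid
    print(k, toy_sched.get_last_lr())   # prints roughly 1.0, 0.7, 0.49
    toy_sched.step()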
Example #29
0
def main():  
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071), # first 80% used for training
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True) # seq_mode implemented  
    
    mval_loader  = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True, # True, because we use part of trainset as testset
                                      data_sel=(99965071, 104965071),#(99965071, 124950714), # last 20% used as the test split
                                      batch_size=TS_BATCH_SZ,
                                      shuffle=False,
                                      seq_mode=True) 
    
    # Load Teacher net
    SMT = SeqModel().cuda(GPU) 
    checkpoint = torch.load(FPATH_T_NET_CHECKPOINT, map_location='cuda:{}'.format(GPU))
    tqdm.write("Loading saved teacher model from '{0:}'... loss: {1:.6f}".format(FPATH_T_NET_CHECKPOINT,checkpoint['loss']))
    SMT.load_state_dict(checkpoint['SM_state'])
    
    SMT_Enc  = nn.Sequential(*list(SMT.children())[:1]).cuda(GPU)
    #SMT_EncFeat = nn.Sequential(*list(SMT.children())[:2])
    
    
    # Init Student net --> copy classifier from the Teacher net
    SM = SeqModel_Student().cuda(GPU)
    SM.feature = deepcopy(SMT.feature)
#    for p in list(SM.feature.parameters()):
#        p.requires_grad = False
    SM.classifier = deepcopy(SMT.classifier)
#    SM.classifier.weight.requires_grad = False
#    SM.classifier.bias.requires_grad = False
    SM = SM.cuda(GPU)
    Distill_parameters = SM.enc.parameters()
    Classifier_parameters = [{'params': SM.feature.parameters()},
                              {'params': SM.classifier.parameters()}]
    
    SM_optim = torch.optim.Adam(Distill_parameters, lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.9)  
    SM2_optim = torch.optim.Adam(Classifier_parameters, lr=LEARNING_RATE)
    
    
    
    # Load checkpoint
    if args.load_continue_latest is None:
        START_EPOCH = 0        
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),key=os.path.getctime)  
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath,checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']
        
    # Train    
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query    = 0
        total_trloss   = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SMT.eval(); # Teacher-net
            SM.train(); # Student-net
            x, labels, y_mask, num_items, index = next(tr_sessions_iter) # FIXED 13.Dec. Separate logs; the query should not include logs.

            # Sample data for 'support' and 'query': e.g. 15 items = 7 support, 8 query items.
            num_support = num_items[:,0].detach().numpy().flatten() # If num_items is odd, the query gets one extra item.
            num_query   = num_items[:,1].detach().numpy().flatten()
            batch_sz    = num_items.shape[0]
            
            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0,2,1) # bx70*20
            
            # x_feat_T: Teacher-net input, x_feat_S: Student-net input(que-log is excluded)
            x_feat_T = torch.zeros(batch_sz, 72, 20)
            x_feat_T[:,:70,:] = x.clone()
            x_feat_T[:, 70,:10] = 1 # Sup/Que state indicator  
            x_feat_T[:, 71,:10] = labels[:,:10].clone()
                        
            x_feat_S = x_feat_T.clone()
            x_feat_S[:, :41, 10:] = 0 # remove que-log
            
            x_feat_T = x_feat_T.cuda(GPU)
            x_feat_S = Variable(x_feat_S).cuda(GPU)
            
            
            # Target: Prepare Teacher's intermediate output 
            enc_target = SMT_Enc(x_feat_T)
            #target = SMT_EncFeat(x_feat_T)
            
            # y
            y = labels.clone()
            
            # y_mask
            y_mask_que = y_mask.clone()
            y_mask_que[:,:10] = 0
            
            # Forward & update
            y_hat_enc, y_hat = SM(x_feat_S) # y_hat: b*20
            
            # Calculate distillation loss
            loss1 = F.binary_cross_entropy_with_logits(input=y_hat_enc, target=torch.sigmoid(enc_target.cuda(GPU)))
            loss2 = F.l1_loss(input=y_hat_enc, target=enc_target.cuda(GPU))
            loss = loss1+loss2
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward(retain_graph=True)
            # Update Enc
            SM_optim.step()
            
            # Calculate Classifier loss
            loss_c = F.binary_cross_entropy_with_logits(input=y_hat*y_mask_que.cuda(GPU), target=y.cuda(GPU)*y_mask_que.cuda(GPU))
            SM.zero_grad()
            loss_c.backward()
            # Update Classifier and feature
            SM2_optim.step()
            
            # Decision
            SM.eval();
            y_prob = torch.sigmoid(y_hat*y_mask_que.cuda(GPU)).detach().cpu().numpy() # bx20               
            y_pred = (y_prob[:,10:]>0.5).astype(int) # bx10; plain int, since np.int was removed in NumPy 1.24
            y_numpy = labels[:,10:].numpy() # bx10
            # Acc
            total_corrects += np.sum((y_pred==y_numpy)*y_mask_que[:,10:].numpy())
            total_query += np.sum(num_query)
            
            # Restore GPU memory
            del loss, loss_c, y_hat, y_hat_enc
    
            if (session+1)%500 == 0:
                hist_trloss.append(total_trloss/900)
                hist_tracc.append(total_corrects/total_query)
                # Prepare display
                sample_sup = labels[0,(10-num_support[0]):10].long().numpy().flatten() 
                sample_que = y_numpy[0,:num_query[0]].astype(int)
                sample_pred = y_pred[0,:num_query[0]]
                sample_prob = y_prob[0,10:10+num_query[0]]

                tqdm.write("S:" + np.array2string(sample_sup) +'\n'+
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query    = 0
                total_trloss   = 0
                
            
            if (session+1)%25000 == 0:
                 # Validation
                 validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                 # Save
                 torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                             'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                             'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
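
# Illustrative aside (not from the original example): only SM_optim is
# scheduled above; SM2_optim keeps a constant learning rate. With
# StepLR(step_size=1, gamma=0.9) stepped once per epoch, the distillation
# learning rate after k epochs is LEARNING_RATE * 0.9 ** k. A quick check of
# that closed form with a stand-in base rate:
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

base_lr = 1e-3                      # stand-in for LEARNING_RATE
toy_opt = optim.Adam(nn.Linear(3, 3).parameters(), lr=base_lr)
toy_sched = StepLR(toy_opt, step_size=1, gamma=0.9)
for k in range(5):
    assert abs(toy_opt.param_groups[0]['lr'] - base_lr * 0.9 ** k) < 1e-12
    toy_opt.step()
    toy_sched.step()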
Example #30
0
def train(train_dir, model_dir, config_path, checkpoint_path,
          n_steps, save_every, test_every, decay_every,
          n_speakers, n_utterances, seg_len):
    """Train a d-vector network."""

    # setup
    total_steps = 0

    # load data
    dataset = SEDataset(train_dir, n_utterances, seg_len)
    train_set, valid_set = random_split(dataset, [len(dataset)-2*n_speakers,
                                                  2*n_speakers])
    train_loader = DataLoader(train_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    train_iter = iter(train_loader)

    assert len(train_set) >= n_speakers
    assert len(valid_set) >= n_speakers
    print(f"Training starts with {len(train_set)} speakers. "
          f"(and {len(valid_set)} speakers for validation)")

    # build network and training tools
    dvector = DVector().load_config_file(config_path)
    criterion = GE2ELoss()
    optimizer = SGD(list(dvector.parameters()) +
                    list(criterion.parameters()), lr=0.01)
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    # load checkpoint
    if checkpoint_path is not None:
        ckpt = torch.load(checkpoint_path)
        total_steps = ckpt["total_steps"]
        dvector.load_state_dict(ckpt["state_dict"])
        criterion.load_state_dict(ckpt["criterion"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dvector = dvector.to(device)
    criterion = criterion.to(device)
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    # start training
    for step in pbar:

        total_steps += 1

        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)

        loss = criterion(embd)

        optimizer.zero_grad()
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            list(dvector.parameters()) + list(criterion.parameters()), max_norm=3)
        dvector.embedding.weight.grad.data *= 0.5
        criterion.w.grad.data *= 0.01
        criterion.b.grad.data *= 0.01

        optimizer.step()
        scheduler.step()

        pbar.set_description(f"global = {total_steps}, loss = {loss:.4f}")
        writer.add_scalar("Training loss", loss, total_steps)
        writer.add_scalar("Gradient norm", grad_norm, total_steps)

        if (step + 1) % test_every == 0:
            batch = next(iter(valid_loader))
            embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)
            loss = criterion(embd)
            writer.add_scalar("validation loss", loss, total_steps)

        if (step + 1) % save_every == 0:
            ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
            ckpt_dict = {
                "total_steps": total_steps,
                "state_dict": dvector.state_dict(),
                "criterion": criterion.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(ckpt_dict, ckpt_path)

    print("Training completed.")