예제 #1
0
        #     model.rel_weights.data.masked_fill_(sign_changed, 0)
        # log
        logger.log('train_iter.mse_dyn', mse_dyn.item())
        logs_train['mse_dyn'] += mse_dyn.item() * len(batch)
        logs_train['loss_dyn'] += loss_dyn.item() * len(batch)

    # --- logs ---
    # TODO:
    # logs_train['mse_dec'] /= nex_dec
    # logs_train['mse_dyn'] /= nex_dyn
    # logs_train['loss_dyn'] /= nex_dyn
    logs_train['loss'] = logs_train['mse_dec'] + logs_train['loss_dyn']
    logger.log('train_epoch', logs_train)
    # checkpoint
    # logger.log('train_epoch.lr', lr)
    logger.checkpoint(model)
    # ------------------------ Test ------------------------
    if opt.test:
        model.eval()
        with torch.no_grad():
            x_pred, _ = model.generate(opt.nt - opt.nt_train)
            score = rmse(x_pred, test_data)
        pb.set_postfix(loss=logs_train['loss'], test=score)
        logger.log('test_epoch.rmse', score)
    else:
        pb.set_postfix(loss=logs_train['loss'])
    # schedule lr
    # if opt.patience > 0 and score < 12:
    #     lr_scheduler.step(score)
    # lr = optimizer.param_groups[0]['lr']
    # if lr <= 1e-5:
예제 #2
0
def train(command=False):
    """Train a spatio-temporal model, then evaluate it and save the results.

    The function builds its options, loads the data, trains the decoder and
    the dynamic module in alternation for ``opt.nepoch`` epochs, evaluates on
    the test split, and writes the rescaled predictions (one text file per
    feature dimension), ``config.json`` and the model to the experiment
    directory.

    Args:
        command: When truthy, options are parsed from the command line with
            configargparse and a tqdm progress bar is displayed. Otherwise a
            hard-coded option set is used and progress is printed per epoch.

    Raises:
        ValueError: If ``opt.model`` or ``opt.optimizer`` is not one of the
            supported names.
    """
    if command:
        #######################################################################################################################
        # Options - CUDA - Random seed
        #######################################################################################################################
        p = configargparse.ArgParser()
        # -- data
        p.add('--datadir', type=str, help='path to dataset', default='data')
        p.add('--dataset', type=str, help='dataset name', default='ncov_confirmed')
        p.add('--nt_train', type=int, help='time for training', default=50)
        p.add('--validation_ratio', type=float, help='validation/train', default=0.1)
        p.add('--start_time', type=int, help='start time for data', default=0)
        p.add('--rescaled', type=str, help='rescaled method', default='d')
        p.add('--normalize_method', type=str, help='normalize method for relation', default='all')

        # -- xp
        p.add('--outputdir', type=str, help='path to save xp', default='default')
        p.add('--xp', type=str, help='xp name', default='stnn')
        p.add('--xp_auto', type=boolean_string, help='time', default=False)
        p.add('--xp_time', type=boolean_string, help='xp_time', default=True)
        p.add('--auto', type=boolean_string, help='dataset_model + time', default=False)
        # -- model
        p.add('--model', type=str, help='STNN Model', default='default')
        p.add('--mode', type=str, help='STNN mode (default|refine|discover)', default='default')
        p.add('--nz', type=int, help='laten factors size', default=1)
        p.add('--activation', type=str, help='dynamic module activation function (identity|tanh)', default='tanh')
        p.add('--khop', type=int, help='spatial depedencies order', default=1)
        p.add('--nhid', type=int, help='dynamic function hidden size', default=0)
        p.add('--nlayers', type=int, help='dynamic function num layers', default=1)
        p.add('--nhid_de', type=int, help='dynamic function hidden size', default=0)
        p.add('--nlayers_de', type=int, help='dynamic function num layers', default=1)
        p.add('--dropout_f', type=float, help='latent factors dropout', default=.5)
        p.add('--dropout_d', type=float, help='dynamic function dropout', default=.5)
        p.add('--dropout_de', type=float, help='dynamic function dropout', default=.5)
        p.add('--lambd', type=float, help='lambda between reconstruction and dynamic losses', default=.1)
        # -- optim
        p.add('--lr', type=float, help='learning rate', default=3e-3)
        p.add('--optimizer', type=str, help='learning algorithm', default='Adam')
        p.add('--beta1', type=float, default=.0, help='adam beta1')
        p.add('--beta2', type=float, default=.999, help='adam beta2')
        p.add('--eps', type=float, default=1e-9, help='adam eps')
        p.add('--wd', type=float, help='weight decay', default=1e-6)
        p.add('--wd_z', type=float, help='weight decay on latent factors', default=1e-7)
        p.add('--l2_z', type=float, help='l2 between consecutives latent factors', default=0.)
        p.add('--l1_rel', type=float, help='l1 regularization on relation discovery mode', default=0.)
        p.add('--sch_bound', type=float, help='learning rate', default=0.001)
        # -- learning
        p.add('--batch_size', type=int, default=1131, help='batch size')
        p.add('--patience', type=int, default=150, help='number of epoch to wait before trigerring lr decay')
        p.add('--nepoch', type=int, default=10, help='number of epochs to train for')
        p.add('--test', type=boolean_string, default=False, help='test during training')

        # -- gpu
        p.add('--device', type=int, default=-1, help='-1: cpu; > -1: cuda device id')
        # -- seed
        p.add('--manualSeed', type=int, help='manual seed')
        # -- logs
        p.add('--checkpoint_interval', type=int, default=100, help='check point interval')

        # parse
        opt = DotDict(vars(p.parse_args()))

    else:
        print('Use Matlab')
        opt = DotDict()
        # -- data
        opt.datadir = 'data'
        opt.dataset = 'ncov_confirmed'
        opt.nt_train = 15
        opt.validation_ratio = 0.1
        opt.start_time = 0
        opt.rescaled = 'd'
        opt.normalize_method = 'row'
        # -- xp
        opt.outputdir = 'default'
        opt.xp = 'stnn'
        opt.xp_auto = False
        opt.xp_time = True
        opt.auto = False
        # -- model
        # The options below that mirror the command-line defaults (model,
        # *_de, optimizer, validation_ratio, manualSeed, checkpoint_interval)
        # are read later in this function; without them no model or optimizer
        # is ever built on this path.
        opt.model = 'default'
        opt.mode = 'default'
        opt.nz = 1
        opt.activation = 'tanh'
        opt.khop = 1
        opt.nhid = 0
        opt.nlayers = 1
        opt.nhid_de = 0
        opt.nlayers_de = 1
        opt.dropout_f = .5
        opt.dropout_d = .5
        opt.dropout_de = .5
        opt.lambd = .1
        # -- optim
        opt.lr = 3e-3
        opt.optimizer = 'Adam'
        opt.beta1 = .0
        opt.beta2 = .999
        opt.eps = 1e-9
        opt.wd = 1e-6
        opt.wd_z = 1e-7
        opt.l2_z = 0.
        opt.l1_rel = 0.
        opt.sch_bound = 0.017
        # -- learning
        opt.batch_size = 1000
        opt.patience = 150
        opt.nepoch = 100
        opt.test = False
        opt.device = -1
        # -- seed / logs
        opt.manualSeed = None
        opt.checkpoint_interval = 100
        print(opt)

    # Output directory: `--auto` forces the "<dataset>_<mode>" name, exactly
    # as the 'default' placeholder does. This must run before opt.mode is
    # normalised to None below.
    if opt.auto or opt.outputdir == 'default':
        opt.outputdir = opt.dataset + "_" + opt.mode
    opt.outputdir = get_dir(opt.outputdir)

    # Experiment name: optional time suffix, or the time alone when
    # `--xp_auto` / `--auto` is set.
    if opt.xp_time:
        opt.xp = opt.xp + "_" + get_time()
    if opt.xp_auto or opt.auto:
        opt.xp = get_time()
    opt.mode = opt.mode if opt.mode in ('refine', 'discover') else None
    opt.xp = 'ori-' + opt.xp
    opt.start = time_dir()
    start_st = datetime.datetime.now()
    opt.st = start_st.strftime('%y-%m-%d-%H-%M-%S')
    # cudnn
    if opt.device > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.device)
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    # seed
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.device > -1:
        torch.cuda.manual_seed_all(opt.manualSeed)

    #######################################################################################################################
    # Data
    #######################################################################################################################
    # -- load data
    setup, (train_data, test_data, validation_data), relations = get_stnn_data(
        opt.datadir, opt.dataset, opt.nt_train, opt.khop, opt.start_time,
        rescaled_method=opt.rescaled, normalize_method=opt.normalize_method,
        validation_ratio=opt.validation_ratio)
    train_data = train_data.to(device)
    test_data = test_data.to(device)
    # The validation split is compared with model output on `device` during
    # training, so it has to live on the same device as the other splits.
    if torch.is_tensor(validation_data):
        validation_data = validation_data.to(device)
    relations = relations.to(device)
    # Dataset-derived settings are merged into the options.
    for k, v in setup.items():
        opt[k] = v

    # -- train inputs
    t_idx = torch.arange(opt.nt_train, out=torch.LongTensor()).unsqueeze(1).expand(opt.nt_train, opt.nx).contiguous()
    x_idx = torch.arange(opt.nx, out=torch.LongTensor()).expand_as(t_idx).contiguous()
    # dynamic: every (t, x) pair except the first time step
    idx_dyn = torch.stack((t_idx[1:], x_idx[1:])).view(2, -1).to(device)
    nex_dyn = idx_dyn.size(1)
    # decoder: every (t, x) pair
    idx_dec = torch.stack((t_idx, x_idx)).view(2, -1).to(device)
    nex_dec = idx_dec.size(1)

    #######################################################################################################################
    # Model
    #######################################################################################################################
    if opt.model == 'default':
        model = SaptioTemporalNN(relations, opt.nx, opt.nt_train, opt.nd, opt.nz, opt.mode, opt.nhid, opt.nlayers,
                            opt.dropout_f, opt.dropout_d, opt.activation, opt.periode).to(device)
    elif opt.model == 'GRU':
        model = SaptioTemporalNN_GRU(relations, opt.nx, opt.nt_train, opt.nd, opt.nz, opt.mode, opt.nhid, opt.nlayers,
                            opt.dropout_f, opt.dropout_d, opt.activation, opt.periode).to(device)
    elif opt.model == 'LSTM':
        model = SaptioTemporalNN_LSTM(relations, opt.nx, opt.nt_train, opt.nd, opt.nz, opt.mode, opt.nhid, opt.nlayers,
                            opt.dropout_f, opt.dropout_d, opt.activation, opt.periode).to(device)
    elif opt.model == 'ld':
        model = SaptioTemporalNN_largedecoder(relations, opt.nx, opt.nt_train, opt.nd, opt.nz, opt.mode, opt.nhid, opt.nlayers, opt.nhid_de, opt.nlayers_de, opt.dropout_de,
                            opt.dropout_f, opt.dropout_d, opt.activation, opt.periode).to(device)
    else:
        raise ValueError("unknown model {!r}; expected one of 'default', 'GRU', 'LSTM', 'ld'".format(opt.model))

    #######################################################################################################################
    # Optimizer
    #######################################################################################################################
    params = [{'params': model.factors_parameters(), 'weight_decay': opt.wd_z},
            {'params': model.dynamic.parameters()},
            {'params': model.decoder.parameters()}]
    if opt.mode in ('refine', 'discover'):
        params.append({'params': model.rel_parameters(), 'weight_decay': 0.})

    if opt.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.wd)
    elif opt.optimizer == 'SGD':
        optimizer = optim.SGD(params, lr=opt.lr, weight_decay=opt.wd)
    elif opt.optimizer == 'Rmsprop':
        optimizer = optim.RMSprop(params, lr=opt.lr, weight_decay=opt.wd)
    elif opt.optimizer == 'Adagrad':
        optimizer = optim.Adagrad(params, lr=opt.lr, weight_decay=opt.wd)
    else:
        raise ValueError(
            "unknown optimizer {!r}; expected one of 'Adam', 'SGD', 'Rmsprop', 'Adagrad'".format(opt.optimizer))

    if opt.patience > 0:
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=opt.patience)

    #######################################################################################################################
    # Logs
    #######################################################################################################################
    logger = Logger(opt.outputdir, opt.xp, opt.checkpoint_interval)

    #######################################################################################################################
    # Training
    #######################################################################################################################
    lr = opt.lr
    opt.mintest = 1000.0
    # Defined up front so the final summary still works when nepoch == 0.
    logs_train = defaultdict(float)
    if command:
        pb = trange(opt.nepoch)
    else:
        pb = range(opt.nepoch)
    for e in pb:
        # ------------------------ Train ------------------------
        model.train()
        # --- decoder ---
        idx_perm = torch.randperm(nex_dec).to(device)
        batches = idx_perm.split(opt.batch_size)
        logs_train = defaultdict(float)
        for batch in batches:
            optimizer.zero_grad()
            # data
            input_t = idx_dec[0][batch]
            input_x = idx_dec[1][batch]
            x_target = train_data[input_t, input_x]
            # closure
            x_rec = model.dec_closure(input_t, input_x)
            mse_dec = F.mse_loss(x_rec, x_target)
            # backward
            mse_dec.backward()
            # step
            optimizer.step()
            # log (weighted by batch size so the epoch value is a true mean)
            logs_train['mse_dec'] += mse_dec.item() * len(batch)
        # --- dynamic ---
        idx_perm = torch.randperm(nex_dyn).to(device)
        batches = idx_perm.split(opt.batch_size)
        for batch in batches:
            optimizer.zero_grad()
            # data
            input_t = idx_dyn[0][batch]
            input_x = idx_dyn[1][batch]
            # closure: predict the factors at t from those at t - 1
            z_inf = model.factors[input_t, input_x]
            z_pred = model.dyn_closure(input_t - 1, input_x)
            # loss
            mse_dyn = z_pred.sub(z_inf).pow(2).mean()
            loss_dyn = mse_dyn * opt.lambd
            if opt.l2_z > 0:
                loss_dyn += opt.l2_z * model.factors[input_t - 1, input_x].sub(model.factors[input_t, input_x]).pow(2).mean()
            if opt.mode in ('refine', 'discover') and opt.l1_rel > 0:
                loss_dyn += opt.l1_rel * model.get_relations().abs().mean()
            # backward
            loss_dyn.backward()
            # step
            optimizer.step()
            # log
            logs_train['mse_dyn'] += mse_dyn.item() * len(batch)
            logs_train['loss_dyn'] += loss_dyn.item() * len(batch)

        # --- logs ---
        logs_train['mse_dec'] /= nex_dec
        logs_train['mse_dyn'] /= nex_dyn
        logs_train['loss_dyn'] /= nex_dyn
        logs_train['loss'] = logs_train['mse_dec'] + logs_train['loss_dyn']
        logger.log('train_epoch', logs_train)
        # checkpoint
        logger.checkpoint(model)
        # ------------------------ Test ------------------------
        if opt.test:
            model.eval()
            with torch.no_grad():
                x_pred, _ = model.generate(opt.validation_length)
                score = rmse(x_pred, validation_data)
            if command:
                pb.set_postfix(loss=logs_train['loss'], test=score)
            else:
                print(e, 'loss=', logs_train['loss'], 'test=', score)
            logger.log('test_epoch.rmse', score)
            if opt.mintest > score:
                opt.mintest = score
            # schedule lr: the scheduler is only stepped once the validation
            # score has dropped below opt.sch_bound
            if opt.patience > 0 and score < opt.sch_bound:
                lr_scheduler.step(score)
            lr = optimizer.param_groups[0]['lr']
            if lr <= 1e-5:
                break
        else:
            if command:
                pb.set_postfix(loss=logs_train['loss'])
            else:
                print(e, 'loss=', logs_train['loss'])
    # ------------------------ Test ------------------------
    model.eval()
    with torch.no_grad():
        x_pred, _ = model.generate(opt.nt - opt.nt_train)
        score = rmse(x_pred, test_data)
    # Undo the data normalisation before computing the "true" score.
    # NOTE(review): when opt.normalize is neither 'variance' nor 'min_max' the
    # tensors below stay random (kept from the original) - presumably a
    # placeholder; confirm whether x_pred / test_data should be used instead.
    true_pred_data = torch.randn_like(x_pred)
    true_test_data = torch.randn_like(test_data)
    if opt.normalize == 'variance':
        true_pred_data = x_pred * opt.std + opt.mean
        true_test_data = test_data * opt.std + opt.mean
    if opt.normalize == 'min_max':
        # NOTE(review): this adds opt.mean, not opt.min - confirm against the
        # rescaling applied in get_stnn_data before changing it.
        true_pred_data = x_pred * (opt.max - opt.min) + opt.mean
        true_test_data = test_data * (opt.max - opt.min) + opt.mean
    true_score = rmse(true_pred_data, true_test_data)

    # One comma-separated file per feature dimension.
    for i in range(opt.nd):
        d_pred = true_pred_data[:, :, i].cpu().numpy()
        np.savetxt(os.path.join(get_dir(opt.outputdir), opt.xp, 'true_pred_' + str(i).zfill(3) + '.txt'), d_pred, delimiter=',')

    opt.test_loss = score
    opt.true_loss = true_score
    logs_train['loss'] = logs_train['mse_dec'] + logs_train['loss_dyn']
    opt.train_loss = logs_train['loss']
    opt.end = time_dir()
    end_st = datetime.datetime.now()
    opt.et = end_st.strftime('%y-%m-%d-%H-%M-%S')
    opt.time = str(end_st - start_st)
    with open(os.path.join(get_dir(opt.outputdir), opt.xp, 'config.json'), 'w') as f:
        json.dump(opt, f, sort_keys=True, indent=4)
    logger.save(model)
예제 #3
0
def main(opt):
    """Train a language model, evaluate it, and log the results.

    Loads the corpus and model configuration named in ``opt``, builds
    per-timestep validation and test loaders, trains until the learning rate
    falls below ``opt.min_lr`` (or the run is interrupted), then evaluates on
    the validation and test loaders and finalises the logger.

    Args:
        opt: Dict-like option object with attribute access. It is updated in
            place with the loaded configuration and derived values.

    Returns:
        0 on normal completion, 130 if training was stopped by
        ``KeyboardInterrupt``.
    """
    exit_code = 0
    opt.hostname = os.uname()[1]
    opt.running = True
    # cudnn
    if opt.device > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.device)
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    # seed
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.device > -1:
        torch.cuda.manual_seed_all(opt.manualSeed)

    ##################################################################################################################
    # Data
    ##################################################################################################################
    # load config
    data_opt = load_config(os.path.join('config', opt.corpus, opt.config, 'corpus.yaml'))
    opt.update(data_opt)
    # load data
    corpus = Corpus(opt.dataroot)
    # split
    trainset, valset, testset = corpus.split(opt.config, opt.min_freq)
    # dataloaders
    # -- train
    train_loader = Iterator(trainset, opt.batch_size, repeat=False, sort_within_batch=True, device=device)
    # -- val: one loader per timestep, in increasing timestep order
    ts_val = sorted({ex.timestep for ex in valset})
    val_loaders = OrderedDict()
    for t in ts_val:
        # `t=t` freezes the current timestep in the predicate; a bare closure
        # would see whatever value `t` holds when the filter is evaluated.
        val_t = Dataset(valset.examples, valset.fields, filter_pred=lambda x, t=t: x.timestep == t)
        val_t.sort_key = lambda x: len(x.text)
        val_loaders[t] = Iterator(val_t, opt.batch_size, train=False, device=device)
    # -- test
    ts_tests = sorted({ex.timestep for ex in testset})
    test_loaders = OrderedDict()
    if opt.config == 'prediction':
        # In the 'prediction' config the validation loaders are evaluated as
        # part of the test loaders as well.
        for t, loader in val_loaders.items():
            test_loaders[t] = loader
    for t in ts_tests:
        test_t = Dataset(testset.examples, testset.fields, filter_pred=lambda x, t=t: x.timestep == t)
        test_t.sort_key = lambda x: len(x.text)
        test_loaders[t] = Iterator(test_t, opt.batch_size, train=False, device=device)
    # opt
    opt.ntoken = corpus.vocab_size
    opt.padding_idx = corpus.pad_idx
    opt.nts = max(ex.timestep for ex in trainset) + 1
    opt.nwords = sum(len(ex.text) for ex in trainset)
    # print info
    print('Vocab size: {}'.format(opt.ntoken))
    print(f'{len(trainset)} training documents with {opt.nwords} tokens on {opt.nts} timesteps')

    ##################################################################################################################
    # Model
    ##################################################################################################################
    # load config
    model_opt = load_config(os.path.join('config', opt.corpus, opt.config, '{}.yaml'.format(opt.model)))
    opt.update(model_opt)
    # build model
    print('Building model...')
    model = lm_factory(opt).to(device)

    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = get_lm_optimizer(model, opt)
    # Stays None when no scheduling is configured or the configured value is
    # not recognised, so the training loop can test it safely.
    lr_scheduler = None
    if 'lr_scheduling' in opt:
        if opt.lr_scheduling == 'linear':
            # NOTE(review): this is a comparison, not an assignment, so it has
            # no effect. It is left untouched because the loop below stops on
            # the strict test `lr < opt.min_lr`; forcing min_lr to 0 here
            # would keep a linearly decayed lr (which bottoms out at 0) from
            # ever satisfying it. Decide the intended stop rule first.
            opt.min_lr == 0
            opt.niter = opt.niter_burnin + opt.niter_scheduling
            niter = opt.niter_scheduling
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                             lr_lambda=lambda i: max(0, (niter - i) / niter))
        elif opt.lr_scheduling == 'reduce_on_plateau':
            assert opt.min_lr > 0
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                      patience=opt.patience, factor=opt.lr_decay)

    ##################################################################################################################
    # Log
    ##################################################################################################################
    opt.xproot = os.path.join(opt.xproot, opt.corpus, opt.config, opt.model, opt.name)
    print(f'New experiment logged at {opt.xproot}')
    logger = Logger(opt.xproot)
    logger.init(opt)

    ##################################################################################################################
    # Training
    ##################################################################################################################
    print('Training...')
    pb = trange(opt.niter, ncols=0)
    finished = False
    itr = -1
    # Defined up front so the final logging still works if training stops
    # before the first batch completes.
    log_train = {}
    try:
        while not finished:
            for batch in train_loader:
                itr += 1
                model.train()
                # io: inputs are the tokens shifted by one against the targets
                text = batch.text[0][:-1]
                target = batch.text[0][1:]
                timestep = batch.timestep
                # closure
                log_train = model.closure(text, target, timestep, optimizer, opt)
                # eval
                if itr > 0 and itr % opt.niter_checkpoint == 0:
                    model.eval()
                    with torch.no_grad():
                        score, log_val = evaluate_lm(model, val_loaders, opt)
                    # checkpoint
                    log_train['lr'] = optimizer.param_groups[0]['lr']
                    logger.log(itr, 'train', log_train)
                    logger.log(itr, 'val', log_val)
                    logger.checkpoint(itr)
                    # reduce_on_plateau lr scheduling
                    if lr_scheduler and itr >= opt.niter_burnin and opt.lr_scheduling == 'reduce_on_plateau':
                        lr_scheduler.step(score)
                    lr = optimizer.param_groups[0]['lr']
                    if lr < opt.min_lr:
                        finished = True
                        break
                    # progress bar
                    pb.update(opt.niter_checkpoint)
                    pb.set_postfix(chkpt=logger.chkpt, loss=log_train['loss'], score=score, lr=lr)
                # other lr scheduling (stepped every iteration after burn-in)
                if lr_scheduler and itr >= opt.niter_burnin and opt.lr_scheduling != 'reduce_on_plateau':
                    lr_scheduler.step()
                lr = optimizer.param_groups[0]['lr']
                if lr < opt.min_lr:
                    # Only flags completion; the current pass over
                    # train_loader still runs to its end.
                    finished = True
    except KeyboardInterrupt:
        exit_code = 130
    pb.close()
    print('Evaluating...')
    model.eval()
    with torch.no_grad():
        _, log_val = evaluate_lm(model, val_loaders, opt)
        _, results = evaluate_lm(model, test_loaders, opt)
    log_train['lr'] = optimizer.param_groups[0]['lr']
    logger.log(itr, 'train', log_train)
    logger.log(itr, 'val', log_val)
    logger.log(itr, 'test', results)
    logger.checkpoint(itr)
    logger.terminate(model, optimizer)
    return exit_code
예제 #4
0
            y = y.to(device)

            y_pred = evaluate(encoder, decoder, x)
            loss_val = loss_fn(y_pred, y)
            logs_val['mse'] = loss_val.item()

        # logs evaluation
        logger.log('val', logs_val)

    # general information
    tr.set_postfix(train_mse=logs_train['mse'],
                   val_mse=logs_val['mse'],
                   train_rmse=np.sqrt(logs_train['mse']),
                   val_rmse=np.sqrt(logs_val['mse']),
                   lr=lr)
    logger.checkpoint(encoder, 'encoder')
    logger.checkpoint(decoder, 'decoder')

    # learning rate decay
    if opt.patience > 0:
        lr_scheduler1.step(logs_val['mse'])
        lr = encoder_optimizer.param_groups[0]['lr']

        lr_scheduler2.step(logs_val['mse'])
        lr = decoder_optimizer.param_groups[0]['lr']
    if lr <= 1e-5:
        break
"""
#######################################################################################################################
# TEST
#######################################################################################################################