示例#1
0
 def getema(self, period=14, label="ema"):
     """Return the EMA indicator for ``period``, building and caching it on a miss.

     Indicators are memoized on ``self.analyzer`` under the key
     ``"ema:<period>"``. ``label`` is not part of that key, so a later call
     with the same period but a different label gets back whatever was
     cached first.

     Args:
         period: forwarded to ``EMA`` as its ``"period"`` setting and used
             in the cache key.
         label: forwarded to ``EMA`` as its ``"label"`` setting when a new
             indicator has to be built.

     Returns:
         The cached indicator when ``getIndicator`` returns a truthy value;
         otherwise a newly built ``EMA`` over ``self.csdata``, which is saved
         back to the analyzer before being returned.
     """
     cache_key = "ema:{}".format(period)
     cached = self.analyzer.getIndicator(cache_key)
     if cached:
         return cached
     built = EMA(self.csdata, {"period": period, "label": label})
     self.analyzer.saveIndicator(cache_key, built)
     return built
示例#2
0
def train(save_pth, use_mixup, mixup_alpha):
    """Train the model for ``n_epochs`` epochs, then save it to ``save_pth``.

    Relies on module-level settings defined elsewhere in the file
    (``ema_alpha``, ``batchsize``, ``n_workers``, ``ds_name``, ``n_epochs``).
    After every epoch the model is evaluated twice. The first pass uses the
    live weights. The second pass uses the EMA shadow weights, which are
    swapped in with ``ema.apply_shadow()`` and swapped back out with
    ``ema.restore()``.

    Args:
        save_pth: destination passed to ``save_model``.
        use_mixup: forwarded to ``train_one_epoch``.
        mixup_alpha: forwarded to ``train_one_epoch``.

    Returns:
        The trained model, in the state left after ``ema.restore()``.
    """
    model, criteria = set_model()
    ema = EMA(model, ema_alpha)

    optim, lr_scheduler = set_optimizer(model)

    dltrain = get_train_loader(batch_size=batchsize,
                               num_workers=n_workers,
                               dataset=ds_name,
                               pin_memory=False)

    for e in range(n_epochs):
        tic = time.time()
        # Capture the learning rate used for *this* epoch before the
        # scheduler advances it; reading it after ``step()`` would log the
        # next epoch's value.
        epoch_lr = optim.param_groups[0]['lr']

        loss_avg = train_one_epoch(model, criteria, dltrain, optim, ema,
                                   use_mixup, mixup_alpha)
        lr_scheduler.step()
        acc = evaluate(model, verbose=False)
        # Evaluate with the EMA weights, and always swap the live weights
        # back in, even if evaluation raises.
        ema.apply_shadow()
        try:
            acc_ema = evaluate(model, verbose=False)
        finally:
            ema.restore()

        toc = time.time()
        msg = 'epoch: {}, loss: {:.4f}, lr: {:.4f}, acc: {:.4f}, acc_ema: {:.4f}, time: {:.2f}'.format(
            e, loss_avg, epoch_lr, acc, acc_ema, toc - tic)
        print(msg)
    save_model(model, save_pth)
    print('done')
    return model
示例#3
0
def MACD(df, t1=12, t2=26):
    """Return a MACD-style line and its 9-period signal line.

    The two window lengths are reordered if needed so that ``t1`` is the
    shorter one. The MACD line is the difference of the two simple moving
    averages (``SMA(df, t1) - SMA(df, t2)``), wrapped in a Series on
    ``df.index``. The signal line is ``EMA(..., 9)`` over that difference,
    after its first ``t2`` entries are dropped.

    NOTE(review): the textbook MACD is built from *exponential* moving
    averages of the two windows; this version uses ``SMA``. Confirm that
    this is intentional.

    Returns:
        ``(macd_line, signal_line)``.
    """
    if t1 > t2:
        t1, t2 = t2, t1
    spread = SMA(df, t1) - SMA(df, t2)
    macd_line = pd.Series(spread, index=df.index)
    # Skip the leading ``t2`` rows before smoothing. These are presumably
    # the warm-up region of the slower average; verify against SMA.
    signal_line = EMA(macd_line[t2:], 9)
    return macd_line, signal_line
示例#4
0
 def getema(self, period=14, label="ema", metric="closed"):
     """Return the EMA indicator for ``(period, metric)``, caching it on a miss.

     Indicators are memoized on ``self.analyzer`` under the key
     ``"ema:<period>:<metric>"``. ``label`` is not part of that key, so a
     later call that differs only in label gets back the cached indicator.

     Args:
         period: forwarded to ``EMA`` as ``"period"``; part of the cache key.
         label: forwarded to ``EMA`` as ``"label"`` when building.
         metric: forwarded to ``EMA`` as ``"metric"``; part of the cache key.

     Returns:
         The cached indicator when ``getIndicator`` returns a truthy value;
         otherwise a newly built ``EMA`` over ``self.csdata``, which is saved
         back to the analyzer before being returned.
     """
     cache_key = "ema:{}:{}".format(period, metric)
     cached = self.analyzer.getIndicator(cache_key)
     if cached:
         return cached
     settings = {"period": period, "label": label, "metric": metric}
     built = EMA(self.csdata, settings)
     self.analyzer.saveIndicator(cache_key, built)
     return built
示例#5
0
 def __init__(self,
              net,
              dataloader=None,
              params=None,
              update_fn=None,
              eval_loader=None,
              test_loader=None,
              has_cuda=True):
     """Place ``net`` on a device and, if ``params`` is given, set up training state.

     Args:
         net: model to train; it is moved to ``cuda:0`` when ``has_cuda`` is
             true and to the CPU otherwise.
         dataloader: stored as ``self.loader``.
         params: optional settings dict. When provided it must contain
             ``n_data``, ``num_classes``, ``n_eval_data``, ``batch_size`` and
             ``polyak_decay``. Only in that case are the prediction buffers,
             optimizer, ``update_fn``, ``ema`` and ``unsup_weight``
             attributes created.
         update_fn: stored as ``self.update_fn`` (only when ``params`` is given).
         eval_loader: stored as ``self.eval_loader``.
         test_loader: stored as ``self.test_loader``.
         has_cuda: choose ``cuda:0`` (True) or ``cpu`` (False).
     """
     device = torch.device("cuda:0") if has_cuda else torch.device('cpu')
     self.net = net.to(device)
     self.loader = dataloader
     self.eval_loader = eval_loader
     self.test_loader = test_loader
     self.params = params
     self.device = device
     if params is None:
         # NOTE(review): without params, optimizer/update_fn/ema/... are
         # never set; callers must not touch them in that mode.
         return

     n_data = params['n_data']
     num_classes = params['num_classes']
     n_eval_data = params['n_eval_data']
     batch_size = params['batch_size']

     def scratch(*shape):
         # Uninitialised float32 buffer on the training device.
         return torch.empty(shape, dtype=torch.float32, device=device)

     # Zero-initialised per-sample prediction accumulators.
     self.ensemble_pred = torch.zeros((n_data, num_classes), device=device)
     self.target_pred = torch.zeros((n_data, num_classes), device=device)
     # Per-sample scratch buffers, overwritten during each epoch.
     self.epoch_pred = scratch(n_data, num_classes)
     self.epoch_mask = scratch(n_data)
     # Per-batch metric buffers (one row per full batch).
     self.train_epoch_loss = scratch(n_data // batch_size, 4)
     self.train_epoch_acc = scratch(n_data // batch_size)
     self.eval_epoch_loss = scratch(n_eval_data // batch_size, 2)
     self.eval_epoch_acc = scratch(n_eval_data // batch_size, 2)
     self.optimizer = opt.Adam(self.net.parameters())
     self.update_fn = update_fn
     self.ema = EMA(params['polyak_decay'], self.net, has_cuda)
     self.unsup_weight = 0.0
示例#6
0
def test(model, ema, args, data):
    """Evaluate ``model`` on the dev set using its EMA weights.

    The trainable parameters are temporarily overwritten with the values
    from ``ema.get(name)``. They are restored afterwards, even if
    evaluation raises. The best-scoring answer span for each example is
    written as JSON to ``args.prediction_file`` and scored with
    ``evaluate.main(args)``.

    Args:
        model: span model; ``model(batch)`` returns start/end logits
            ``(p1, p2)``, each of size ``(batch, c_len)``.
        ema: object whose ``get(name)`` returns the shadow tensor for a
            parameter name.
        args: namespace providing ``gpu`` and ``prediction_file`` (plus
            whatever ``evaluate.main`` reads).
        data: dataset wrapper providing ``dev_iter`` and ``WORD.vocab``.

    Returns:
        ``(loss, exact_match, f1)``. ``loss`` is the sum over dev batches of
        the start plus end cross-entropy losses.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    log_softmax = nn.LogSoftmax(dim=1)
    loss = 0
    answers = dict()
    model.eval()

    # Swap the EMA weights in, keeping the live weights for restoration.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    with torch.set_grad_enabled(False):
        try:
            for batch in iter(data.dev_iter):
                p1, p2 = model(batch)
                batch_loss = criterion(p1, batch.s_idx) + criterion(
                    p2, batch.e_idx)
                loss += batch_loss.item()

                batch_size, c_len = p1.size()
                # mask[s, e] = -inf when s > e (end before start), else 0.
                mask = (torch.ones(c_len, c_len) *
                        float('-inf')).to(device).tril(-1).unsqueeze(0).expand(
                            batch_size, -1, -1)
                # score[b, s, e] = log P(start=s) + log P(end=e):
                # shape (batch, c_len, c_len).
                score = (log_softmax(p1).unsqueeze(2) +
                         log_softmax(p2).unsqueeze(1)) + mask
                score, s_idx = score.max(dim=1)
                score, e_idx = score.max(dim=1)
                # squeeze(1), not squeeze(): a bare squeeze() turns a batch
                # of one into a 0-d tensor, and s_idx[i] would then fail.
                s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)

                for i in range(batch_size):
                    qid = batch.id[i]
                    answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                    answer = ' '.join(
                        [data.WORD.vocab.itos[idx] for idx in answer])
                    answers[qid] = answer
        finally:
            # Always put the live weights back so training can continue.
            for name, param in model.named_parameters():
                if param.requires_grad:
                    param.data.copy_(backup_params.get(name))

    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)

    results = evaluate.main(args)
    return loss, results['exact_match'], results['f1']
示例#7
0
def train():
    """Train with label guessing + MixUp and report EMA accuracy each epoch.

    Relies on module-level settings defined elsewhere in the file
    (``n_imgs_per_epoch``, ``batchsize``, ``n_guesses``, ``temperature``,
    ``mixup_alpha``, ``ema_alpha``, ``lr``, ``lam_u``, ``n_epoches``,
    ``weight_decay``). At the start of epoch ``e``, ``lambda_u`` is set to
    ``e * lam_u / n_epoches``. ``lambda_u_once`` is that per-epoch increment
    divided by the number of iterations per epoch (presumably applied per
    step inside ``train_one_epoch``; confirm there).
    """
    model, criteria_x, criteria_u = set_model()

    n_iters_per_epoch = n_imgs_per_epoch // batchsize
    dltrain_x, dltrain_u = get_train_loader(
        batchsize, n_iters_per_epoch, L=250, K=n_guesses
    )
    lb_guessor = LabelGuessor(model, T=temperature)
    mixuper = MixUp(mixup_alpha)

    ema = EMA(model, ema_alpha)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    lam_u_epoch = float(lam_u) / n_epoches
    lam_u_once = lam_u_epoch / n_iters_per_epoch

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        ema=ema,
        # Multiplicative decay factor, derived from weight_decay and lr.
        wd=1 - weight_decay * lr,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lb_guessor=lb_guessor,
        mixuper=mixuper,
        lambda_u=0,
        lambda_u_once=lam_u_once,
    )
    best_acc = -1
    print('start to train')
    for e in range(n_epoches):
        model.train()
        print('epoch: {}'.format(e))
        train_args['lambda_u'] = e * lam_u_epoch
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        best_acc = max(best_acc, acc)
        log_msg = [
            'epoch: {}'.format(e),
            'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc)]
        print(', '.join(log_msg))
示例#8
0
    def __testEMA(self, strDateTime, strType='O'):
        """Build and display an ``ema.EMA`` for the window ending at ``strDateTime``.

        The window covers the ``self.timespan`` indices up to and including
        the index mapped to ``strDateTime``. For each index it takes the
        cached MA value from ``self.dicTempResultMA_CPU`` and the raw price
        (converted to float) from ``self.dicRawData``. These are passed to
        ``EMA(lstMA, lstPrice, self.timespan)``, and then ``calculate()`` and
        ``show()`` are called on the result.

        Raises:
            AssertionError: if ``strDateTime`` is unknown, or its index has
                no entry in ``self.dicTempResultMA_CPU``.
        """
        targetIdx = self.dicDatetime2Idx.get(strDateTime, -1)
        assert targetIdx != -1, "Incorrect input datetime."
        assert targetIdx in self.dicTempResultMA_CPU, \
            "Target datetime not in current temp result."

        strTypeForRaw = self.__convertTypeToRawType(strType)

        window = xrange(targetIdx - self.timespan + 1, targetIdx + 1)
        # (cached MA, raw price) per index, read in the same order as before.
        samples = [(self.dicTempResultMA_CPU[pos][strType],
                    float(self.dicRawData[pos][strTypeForRaw]))
                   for pos in window]
        lstMA = [ma for ma, _ in samples]
        lstPrice = [price for _, price in samples]

        from ema import EMA
        checker = EMA(lstMA, lstPrice, self.timespan)
        checker.calculate()
        checker.show()
示例#9
0
def main():
    """Train ``ConvLarge`` on CIFAR-10 together with an ``EMA`` "center".

    The model is built with ``num_classes + 1`` outputs (the extra slot
    presumably relates to ``unlabeled_label``; confirm). The EMA center is
    constructed with shape ``(len(train_dataset), num_classes + 1)``.
    Validation runs once before training and again after every epoch.
    """
    opts = load_option()
    print(opts)
    train_dataset, test_dataset = load_cifar10(
        opts, unlabeled_label=opts.unlabeled_label)

    n_outputs = opts.num_classes + 1
    model = ConvLarge(n_outputs)
    center = EMA((len(train_dataset), n_outputs), opts.alpha)
    if opts.cuda:
        for component in (model, center):
            component.to('cuda')

    trainer = Trainer(model, center, train_dataset, test_dataset, opts)

    trainer.validate()
    for _ in range(opts.max_epoch):
        trainer.train()
        trainer.validate()
示例#10
0
def train():
    """Semi-supervised training with SGD + cosine LR and EMA evaluation.

    Weight decay applies to multi-dimensional parameters only. 1-D
    parameters (presumably biases and norm weights) get
    ``weight_decay=0``. After each epoch the EMA model is evaluated and the
    best accuracy so far is logged. When training finishes,
    ``sort_unlabeled(ema)`` is called.
    """
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs

    model, criteria_x, criteria_u = set_model()

    dltrain_x, dltrain_u = get_train_loader(args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled,
                                            seed=args.seed)
    lb_guessor = LabelGuessor(thresh=args.thr)

    ema = EMA(model, args.ema_alpha)

    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if param.dim() == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [{
        'params': wd_params
    }, {
        'params': non_wd_params,
        'weight_decay': 0
    }]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lb_guessor=lb_guessor,
        lambda_u=args.lam_u,
        lambda_c=args.lam_c,
        n_iters=n_iters_per_epoch,
    )
    best_acc = -1
    print('start to train')
    for e in range(args.n_epochs):
        model.train()
        # 1-based epoch number, used for both log lines of this epoch.
        epoch = e + 1
        print('epoch: {}'.format(epoch))
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        best_acc = max(best_acc, acc)
        log_msg = [
            'epoch: {}'.format(epoch), 'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc)
        ]
        print(', '.join(log_msg))

    sort_unlabeled(ema)
def main():
    """Train a PVAE (continuous-conv encoder, grid decoder, classifier) on time series.

    Parses command-line options and builds the model. Training uses Adam
    with separate learning rates for the grid decoder and grid encoder,
    optionally keeping an EMA of the weights. Evaluation runs every
    ``--eval-interval`` epochs. Logs and checkpoints are written under
    ``results/mimic3-pvae/<prefix>_<timestamp>/``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k',
                        type=int,
                        default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k',
                        type=int,
                        default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow',
                        type=int,
                        default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr',
                        type=float,
                        default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr',
                        type=float,
                        default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=-1,
                        help='min learning rate for LR scheduler. '
                        '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=200,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts',
                        type=float,
                        default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl',
                        type=float,
                        default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma',
                        type=float,
                        default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        # Draw a fresh seed from OS entropy. The explicit int64 dtype keeps
        # the uint32-max upper bound valid where NumPy's default integer is
        # 32-bit (e.g. NumPy < 2 on Windows), which would otherwise raise.
        rnd = np.random.RandomState(None)
        random_seed = int(rnd.randint(np.iinfo(np.uint32).max,
                                      dtype=np.int64))
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]

    out_channels = enc_channels[0]

    # Output squashing matches the value range: tanh when rescaled.
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # Each decoder stage after the first doubles the grid length, so the
    # reference grid must be divisible by the total upsampling factor.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pvae = PVAE(encoder, decoder, classifier, args.sigma, args.cls).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    # Grid decoder/encoder get their own learning rates; everything else
    # uses the global --lr.
    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {
            'params': decoder.grid_decoder.parameters(),
            'lr': args.dec_lr
        },
        {
            'params': encoder.grid_encoder.parameters(),
            'lr': args.enc_lr
        },
        {
            'params': other_params
        },
    ]

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae,
                          val_loader,
                          test_loader,
                          log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
            loss, _, _, loss_info = pvae(val, idx, mask, y, cconv_graph,
                                         args.train_k, args.ts, args.kl)
            loss.backward()
            optimizer.step()

            if ema is not None:
                ema.update()

            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema is not None:
                # Evaluate with the averaged weights, and always switch back
                # to the live weights before training continues.
                ema.apply()
                try:
                    evaluator.evaluate(epoch)
                finally:
                    ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema is not None else None,
            'epoch': epoch + 1,
            'args': args,
        }
        # Latest checkpoint is overwritten every epoch; numbered snapshots
        # are kept every --save-interval epochs.
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
示例#12
0
def train_bidaf(args, data):
    """Train BiDAF with Adadelta while tracking an EMA of trainable weights.

    Every ``args.print_freq`` iterations the model is scored on the dev set
    with ``test``, which evaluates the EMA weights. The metrics are logged
    to TensorBoard, and a copy of the best-F1 weights is kept. Training
    stops when the iterator reaches epoch ``args.epoch``.

    Returns:
        A deep copy of the model carrying the EMA weights that produced the
        best dev F1. If no dev evaluation ran (fewer than ``print_freq``
        iterations), a copy carrying the final EMA weights is returned
        instead of raising ``UnboundLocalError``.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(logdir='runs/' + args.model_time)

    def ema_snapshot():
        # Deep copy of the model with the current EMA weights loaded, i.e.
        # the same weights that ``test`` evaluates.
        snapshot = copy.deepcopy(model)
        for pname, pval in snapshot.named_parameters():
            if pval.requires_grad:
                pval.data.copy_(ema.get(pname))
        return snapshot

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    best_model = None

    iterator = data.train_iter
    for i, batch in tqdm(enumerate(iterator)):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)

        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                # ``test`` scored the EMA weights and then restored the live
                # ones, so snapshot the EMA weights that earned this score.
                best_model = ema_snapshot()

            # ``loss`` is the running sum since the last report.
            loss = 0
            model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    if best_model is None:
        best_model = ema_snapshot()
    return best_model
示例#13
0
def run(p_seed=0, p_epochs=150, p_kernel_size=5, p_logdir="temp"):
    """Train an MNIST ``ModelM{3,5,7}`` with EMA weights and log metrics per epoch.

    Args:
        p_seed: seed for torch, CUDA, NumPy and ``RandomRotation``; also
            used in the output file names.
        p_epochs: number of training epochs.
        p_kernel_size: 3, 5 or 7; selects ``ModelM3``, ``ModelM5`` or ``ModelM7``.
        p_logdir: sub-directory of ``../logs`` that receives the outputs.

    Raises:
        SystemExit: if CUDA is unavailable (this script trains on GPU only).
        ValueError: if ``p_kernel_size`` is not 3, 5 or 7.

    Side effects:
        Truncates and then appends one metrics line per epoch to
        ``../logs/<p_logdir>/log<seed>.out``. Whenever test accuracy
        improves, it saves ``model.state_dict()`` (taken while the EMA
        weights are assigned) to ``../logs/<p_logdir>/model<seed>.pth``.
    """
    import sys

    # random number generator seed ------------------------------------------------#
    SEED = p_seed
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # kernel size of model --------------------------------------------------------#
    KERNEL_SIZE = p_kernel_size

    # number of epochs ------------------------------------------------------------#
    NUM_EPOCHS = p_epochs

    # file names ------------------------------------------------------------------#
    os.makedirs("../logs/%s" % p_logdir, exist_ok=True)
    OUTPUT_FILE = str("../logs/%s/log%03d.out" % (p_logdir, SEED))
    MODEL_FILE = str("../logs/%s/model%03d.pth" % (p_logdir, SEED))

    # enable GPU usage ------------------------------------------------------------#
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if not use_cuda:
        # The original message claimed CPU training would proceed, but the
        # script stops here; report that accurately and exit non-zero.
        print("ERROR: CUDA is not available; this script requires a GPU.")
        sys.exit(1)

    # data augmentation methods ---------------------------------------------------#
    transform = transforms.Compose([
        RandomRotation(20, seed=SEED),
        transforms.RandomAffine(0, translate=(0.2, 0.2)),
    ])

    # data loader -----------------------------------------------------------------#
    train_dataset = MnistDataset(training=True, transform=transform)
    test_dataset = MnistDataset(training=False, transform=None)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=120,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=100,
                                              shuffle=False)

    # model selection -------------------------------------------------------------#
    if KERNEL_SIZE == 3:
        model = ModelM3().to(device)
    elif KERNEL_SIZE == 5:
        model = ModelM5().to(device)
    elif KERNEL_SIZE == 7:
        model = ModelM7().to(device)
    else:
        raise ValueError(
            "p_kernel_size must be 3, 5 or 7 (got %r)" % (KERNEL_SIZE,))

    summary(model, (1, 28, 28))

    # hyperparameter selection ----------------------------------------------------#
    ema = EMA(model, decay=0.999)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=0.98)

    # delete result file ----------------------------------------------------------#
    with open(OUTPUT_FILE, 'w'):
        pass

    # global variables ------------------------------------------------------------#
    g_step = 0
    max_correct = 0

    # training and evaluation loop ------------------------------------------------#
    for epoch in range(NUM_EPOCHS):
        #--------------------------------------------------------------------------#
        # train process                                                            #
        #--------------------------------------------------------------------------#
        model.train()
        train_loss = 0
        train_corr = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            train_pred = output.argmax(dim=1, keepdim=True)
            train_corr += train_pred.eq(
                target.view_as(train_pred)).sum().item()
            train_loss += F.nll_loss(output, target, reduction='sum').item()
            loss.backward()
            optimizer.step()
            g_step += 1
            ema(model, g_step)
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                      format(epoch, batch_idx * len(data),
                             len(train_loader.dataset),
                             100. * batch_idx / len(train_loader),
                             loss.item()))
        train_loss /= len(train_loader.dataset)
        train_accuracy = 100 * train_corr / len(train_loader.dataset)

        #--------------------------------------------------------------------------#
        # test process                                                             #
        #--------------------------------------------------------------------------#
        model.eval()
        # Evaluate (and checkpoint) with the EMA weights assigned; they are
        # swapped back out by ``ema.resume`` below.
        ema.assign(model)
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
            if max_correct < correct:
                torch.save(model.state_dict(), MODEL_FILE)
                max_correct = correct
                print("Best accuracy! correct images: %5d" % correct)
        ema.resume(model)

        #--------------------------------------------------------------------------#
        # output                                                                   #
        #--------------------------------------------------------------------------#
        test_loss /= len(test_loader.dataset)
        test_accuracy = 100 * correct / len(test_loader.dataset)
        best_test_accuracy = 100 * max_correct / len(test_loader.dataset)
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%) (best: {:.2f}%)\n'
            .format(test_loss, correct, len(test_loader.dataset),
                    test_accuracy, best_test_accuracy))

        with open(OUTPUT_FILE, 'a') as f:
            f.write(" %3d %12.6f %9.3f %12.6f %9.3f %9.3f\n" %
                    (epoch, train_loss, train_accuracy, test_loss,
                     test_accuracy, best_test_accuracy))

        #--------------------------------------------------------------------------#
        # update learning rate scheduler                                           #
        #--------------------------------------------------------------------------#
        lr_scheduler.step()
示例#14
0
def _decode_best_span(p1, p2):
    """Return the best-scoring (start, end) span per example, with start <= end.

    Args:
        p1: start logits of shape (batch, context_len).
        p2: end logits of the same shape as ``p1``.

    Returns:
        ``(s_idx, e_idx)``: 1-D index tensors of length ``batch`` giving the
        start and end position of the highest-scoring valid span.
    """
    _, c_len = p1.size()
    log_softmax = nn.LogSoftmax(dim=1)
    # mask[s, e] is -inf when e < s, so spans that end before they start are
    # never selected. Built on p1's device to avoid a CPU/GPU mismatch.
    mask = torch.full((c_len, c_len), float('-inf'), device=p1.device).tril(-1)
    # score[b, s, e] = log P(start=s) + log P(end=e)
    score = (log_softmax(p1).unsqueeze(2) +
             log_softmax(p2).unsqueeze(1)) + mask.unsqueeze(0)
    # Best start for every end position, then the best end overall.
    score, s_idx = score.max(dim=1)
    score, e_idx = score.max(dim=1)
    # squeeze(1) rather than squeeze(): keeps the batch axis when batch == 1,
    # otherwise s_idx would become 0-d and s_idx[i] would fail.
    s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)
    return s_idx, e_idx


def train_val_model(pipeline_cfg, model_cfg, train_cfg):
    """Train BiDAF with an EMA of its weights, validating every epoch.

    Args:
        pipeline_cfg: keyword arguments forwarded to ``DataPipeline``.
        model_cfg: keyword arguments forwarded to ``BiDAF``. If
            ``model_cfg['cxt_emb_pretrained']`` is not None it is passed to
            ``torch.load`` and replaced (in place) by the loaded object.
        train_cfg: training options; this function reads ``exp_decay_rate``,
            ``lr``, ``num_epochs``, ``batch_per_disp`` and ``ckpoint_file``.

    Each epoch has a 'train' phase (Adadelta step, then an EMA update of every
    trainable parameter after each batch) and a 'val' phase evaluated with the
    EMA weights swapped in. Whenever validation F1 improves, ``result`` (best
    F1/EM plus a deep copy of the EMA-weighted state dict) is written to
    ``train_cfg['ckpoint_file']``. The live weights are always restored after
    validation.
    """
    data_pipeline = DataPipeline(**pipeline_cfg)
    if model_cfg['cxt_emb_pretrained'] is not None:
        model_cfg['cxt_emb_pretrained'] = torch.load(
            model_cfg['cxt_emb_pretrained'])
    bidaf = BiDAF(word_emb=data_pipeline.word_type.vocab.vectors, **model_cfg)

    # Shadow (moving-average) copy of every trainable parameter.
    ema = EMA(train_cfg['exp_decay_rate'])
    for name, param in bidaf.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, bidaf.parameters())
    optimizer = optim.Adadelta(parameters, lr=train_cfg['lr'])
    criterion = nn.CrossEntropyLoss()

    result = {'best_f1': 0.0, 'best_model': None}

    num_epochs = train_cfg['num_epochs']
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        for phase in ['train', 'val']:
            val_f1 = 0
            val_em = 0
            val_cnt = 0
            val_r = 0
            # Live (non-EMA) weights saved before validation; empty in 'train'.
            backup_params = {}

            if phase == 'train':
                bidaf.train()
            else:
                bidaf.eval()
                # Evaluate with the EMA weights. The live weights are cloned
                # explicitly so the backup cannot alias the tensors that
                # copy_() is about to overwrite.
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        backup_params[name] = param.data.clone()
                        param.data.copy_(ema.get(name))

            try:
                with torch.set_grad_enabled(phase == 'train'):
                    for batch_num, batch in enumerate(
                            data_pipeline.data_iterators[phase]):
                        optimizer.zero_grad()
                        p1, p2 = bidaf(batch)

                        if phase == 'train':
                            loss = criterion(p1, batch.s_idx) + criterion(
                                p2, batch.e_idx)
                            loss.backward()
                            optimizer.step()
                            for name, param in bidaf.named_parameters():
                                if param.requires_grad:
                                    ema.update(name, param.data)
                            if batch_num % train_cfg['batch_per_disp'] == 0:
                                batch_loss = loss.item()
                                print('batch %d: loss %.3f' %
                                      (batch_num, batch_loss))
                        else:
                            s_idx, e_idx = _decode_best_span(p1, p2)
                            batch_size = p1.size(0)
                            val_cnt += batch_size
                            for i in range(batch_size):
                                answer = (s_idx[i], e_idx[i])
                                gt = (batch.s_idx[i], batch.e_idx[i])
                                val_f1 += f1_score(answer, gt)
                                val_em += exact_match_score(answer, gt)
                                val_r += r_score(answer, gt)

                if phase == 'val':
                    val_f1 = val_f1 * 100 / val_cnt
                    val_em = val_em * 100 / val_cnt
                    val_r = val_r * 100 / val_cnt
                    print('Epoch %d: %s f1 %.3f | %s em %.3f |  %s rouge %.3f' %
                          (epoch, phase, val_f1, phase, val_em, phase, val_r))
                    if val_f1 > result['best_f1']:
                        result['best_f1'] = val_f1
                        result['best_em'] = val_em
                        # Saved while the EMA weights are installed.
                        result['best_model'] = copy.deepcopy(bidaf.state_dict())
                        torch.save(result, train_cfg['ckpoint_file'])
            finally:
                # Put the live weights back even if validation raised.
                for name, param in bidaf.named_parameters():
                    if name in backup_params:
                        param.data.copy_(backup_params[name])
# Undo input normalisation on the sample image grid, then save it and log it
# to the writer as 'input_images' at step 0.
# NOTE(review): images_grid, unnormalize, save_grid and writer come from the
# earlier part of this script, which is not visible here.
images_grid = unnormalize(images_grid)
save_grid(images_grid)
writer.add_image('input_images', images_grid, 0)


############################# Model ####################################
'''
Classification model is initialized here, along with exponential
moving average (EMA) module:
    - model is pushed to gpu if its available.
'''

# Presumably (depth=28, widen_factor=2, dropout=0.3, num_classes=10) — confirm
# against Wide_ResNet's signature.
net = Wide_ResNet(28, 2, 0.3, 10)  # alternative model: VanillaNet()
# Use the first CUDA device when available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
# Moving-average copy of net's weights with decay 0.9999.
ema = EMA(net, decay=0.9999)


############################## Utils ###################################
'''
Training utils are initialized here, including:
    - CrossEntropyLoss - supervised loss.
    - KLDivLoss - unsupervised consistency loss
    - SGD optimizer
    - CosineAnnealingLR scheduler
    - Evaluation function
'''

criterion_sup = torch.nn.CrossEntropyLoss()
# reduction='none' returns per-element losses; presumably the caller masks or
# reduces them itself — confirm at the use site.
criterion_unsup = torch.nn.KLDivLoss(reduction='none')
optimizer = torch.optim.SGD(
示例#16
0
def train():
    """Run the semi-supervised training loop and record top-1 accuracy.

    Configuration comes from the module-level ``args`` namespace. After every
    epoch the accuracy history is written to
    ``result_dir + '/' + save_name_pre + '.accuracy.csv'``.

    When ``args.bootstrap > 1``, the per-class labelled count is multiplied by
    ``args.bootstrap`` at each epoch in the bootstrap schedule. The EMA model
    ranks the unlabelled data via ``sort_unlabeled``, and the training loaders
    are rebuilt from the returned name.

    Accuracy comes from ``evaluate(ema)`` when ``args.test == 0`` or
    ``args.lam_clr`` is ~0. Otherwise, when ``args.test == 1``, it comes from
    ``test(...)`` on CIFAR-10 memory/test loaders.
    """
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs  # / args.mu_c
    epsilon = 0.000001

    model, criteria_x, criteria_u = set_model()
    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    # 1-D parameters (biases, normalisation scales) are exempt from weight decay.
    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if len(param.size()) == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list, lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum, nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all,
                                        warmup_iter=0)

    def make_train_args(loaders):
        # Bundle the fixed training objects with the current data loaders.
        dltrain_x, dltrain_u, dltrain_all = loaders
        return dict(
            model=model,
            criteria_x=criteria_x,
            criteria_u=criteria_u,
            optim=optim,
            lr_schdlr=lr_schdlr,
            ema=ema,
            dltrain_x=dltrain_x,
            dltrain_u=dltrain_u,
            dltrain_all=dltrain_all,
            lb_guessor=lb_guessor,
        )

    train_args = make_train_args(get_train_loader(
        args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
        L=args.n_labeled, seed=args.seed))
    # Labelled examples per class; grows at each bootstrap step.
    n_labeled = int(args.n_labeled / args.n_classes)
    best_acc, top1 = -1, -1
    results = {'top 1 acc': [], 'best_acc': []}

    # Integer epochs: `e in b_schedule` never matches a float such as 50.5,
    # so with n/2 and 3n/4 an odd n_epochs silently skipped a bootstrap step.
    b_schedule = [args.n_epochs // 2, 3 * args.n_epochs // 4]
    if args.boot_schedule == 1:
        step = int(args.n_epochs / 3)
        b_schedule = [step, 2 * step]
    elif args.boot_schedule == 2:
        step = int(args.n_epochs / 4)
        b_schedule = [step, 2 * step, 3 * step]

    # The kNN evaluation loaders (args.test == 1) are identical every epoch,
    # so they are built once on first use and then reused.
    knn_eval = None

    for e in range(args.n_epochs):
        if args.bootstrap > 1 and e in b_schedule:
            seed = 99
            n_labeled *= args.bootstrap
            name = sort_unlabeled(ema, n_labeled)
            print("Bootstrap at epoch ", e, " Name = ", name)
            # Total labelled count = per-class count * number of classes
            # (was hard-coded as 10, only correct for 10-class datasets).
            train_args = make_train_args(get_train_loader(
                args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
                L=args.n_classes * n_labeled, seed=seed, name=name))

        model.train()
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        if args.test == 0 or args.lam_clr < epsilon:
            top1 = evaluate(ema) * 100
        elif args.test == 1:
            if knn_eval is None:
                memory_data = utils.CIFAR10Pair(root='dataset', train=True, transform=utils.test_transform, download=False)
                memory_data_loader = DataLoader(memory_data, batch_size=args.batchsize, shuffle=False, num_workers=16, pin_memory=True)
                test_data = utils.CIFAR10Pair(root='dataset', train=False, transform=utils.test_transform, download=False)
                test_data_loader = DataLoader(test_data, batch_size=args.batchsize, shuffle=False, num_workers=16, pin_memory=True)
                knn_eval = (memory_data_loader, test_data_loader,
                            len(memory_data.classes))
            memory_data_loader, test_data_loader, c = knn_eval
            top1 = test(model, memory_data_loader, test_data_loader, c, e)
        # NOTE(review): for any other args.test value top1 keeps its previous
        # value (initially -1).

        best_acc = max(best_acc, top1)

        results['top 1 acc'].append('{:.4f}'.format(top1))
        results['best_acc'].append('{:.4f}'.format(best_acc))
        data_frame = pd.DataFrame(data=results)
        data_frame.to_csv(result_dir + '/' + save_name_pre + '.accuracy.csv', index_label='epoch')

        log_msg = [
            'epoch: {}'.format(e + 1),
            'top 1 acc: {:.4f}'.format(top1),
            'best_acc: {:.4f}'.format(best_acc)]
        print(', '.join(log_msg))
示例#17
0
def main():
    """Train a P-BiGAN on the toy irregularly-sampled time-series dataset.

    Parses command-line options and loads ``--data``. If that file is missing
    and is the default name, the toy dataset is generated and saved instead.
    It then builds the encoder/decoder (``PBiGAN``) and the critic.

    Training is adversarial, plus an autoencoding loss (weight ``--ae``) and a
    latent MMD loss (weight ``--mmd``).

    Outputs go under ``results/toy-pbigan/<prefix>_<MMDD.HHMMSS>/``:
      * ``log/``: seed, GPU count and args, plus ``model.pth``, which is
        overwritten every epoch.
      * ``model/``: numbered snapshots every ``--save-interval`` epochs.

    ``visualizer.plot`` is called every ``--plot-interval`` epochs, using the
    EMA weights when EMA is enabled.
    """
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # squash is off when rescale is off
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    # Zero out timestamps of unobserved entries.
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Independently shuffled loader over the same data; it supplies the time
    # points/masks at which the decoder generates samples.
    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()

    for epoch in range(start_epoch, epochs):
        # Reset every epoch so tracker.log receives this epoch's duration.
        # Previously it was set once before the loop, so it reported
        # cumulative time.
        epoch_start = time.time()
        loss_breakdown = defaultdict(float)

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            # NOTE(review): G_loss reuses `real`/`fake`, which were computed
            # before critic_optimizer.step() updated the critic's parameters
            # in place. Recent PyTorch versions may reject loss.backward()
            # below with an in-place-modification error — confirm against the
            # PyTorch version in use.
            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                # Plot with the EMA weights. Restore the live weights even if
                # plotting raises, so training does not continue on EMA weights.
                ema.apply()
                try:
                    visualizer.plot(epoch)
                finally:
                    ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
示例#18
0
    filter(lambda p: p.requires_grad, model.parameters()))
# optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()))

# Resume from args.resume when it names an existing file: restores the epoch
# counter, model weights and optimizer state from the checkpoint keys
# 'epoch', 'state_dict' and 'optimizer'.
# NOTE(review): torch.load unpickles the file; only resume from trusted
# checkpoints.
if os.path.isfile(args.resume):
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    # best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])  # TODO(review): confirm restoring optimizer state is intended
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.resume, checkpoint['epoch']))
else:
    # NOTE(review): also printed when no resume was requested at all.
    print("=> no checkpoint found at '{}'".format(args.resume))

# Register every trainable parameter with an EMA tracker. Presumably 0.999 is
# the decay rate — confirm against EMA's constructor.
ema = EMA(0.999)
for name, param in model.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)

# Print the model architecture and the shape of each trainable parameter.
print(model)
print('parameters-----')
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data.size())

# args.test == 1 runs evaluation only. Otherwise training mode is announced;
# the training code presumably follows outside this view.
if args.test == 1:
    print('Test mode')
    test(model, test_data)
else:
    print('Train mode')
示例#19
0
 def __init__(self, period):
     """Set up empty indicator state and two EMA smoothers of length ``period``."""
     # Nothing has been computed or observed yet.
     self.value = self.last = None
     # Two independent EMA instances sharing the same period
     # (presumably up-move / down-move smoothing, judging by the names).
     self.ema_u, self.ema_d = EMA(period), EMA(period)
     self.tbl = None
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=1.5e-4,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=1,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch',
                        default=None,
                        help='discriminator architecture. Use encoder '
                        'architecture if unspecified.')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--trans',
                        type=int,
                        default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           betas=(0, .999),
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  betas=(0, .999),
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            """Write one line ``"<name> <count>"`` to the open params log ``f``.

            ``module`` is either a single model or a list/tuple of models; for
            a list/tuple the per-item ``count_parameters`` results are summed.
            """
            # Test the container type explicitly instead of wrapping the whole
            # sum in ``try/except TypeError``: that form also swallowed any
            # TypeError raised inside count_parameters itself, and treated any
            # iterable argument as a list rather than only real lists/tuples.
            if isinstance(module, (list, tuple)):
                params_count = sum(count_parameters(m) for m in module)
            else:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start

    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()

        if epoch >= 40:
            args.cls = 200

        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            # Don't need pbigan.requires_grad_(False);
            # critic takes as input only the detached tensors.
            real = critic(cconv_graph, batch_size, z_enc.detach())
            detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls +
                    mmd_loss * args.mmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
    summ = np.vstack((summ, s))

# Label the stacked per-security RSI rows and export them without the index.
rsi_returns = pd.DataFrame(
    data=summ,
    columns=['Security', 'Return', 'Holding Period', 'Purchase Date'],
)
rsi_returns.to_csv(
    r'Z:\School\Sem 4\Research Methodology\Paper\RSI_Returns.csv',
    index=False,
)

# EMA section, configured for 13- and 34-day periods. For other periods, change
# the period_1 and period_2 arguments when constructing the EMA object.

# Start a fresh accumulator: without this reset, the rows collected by the
# preceding section would be mixed into the EMA results exported below.
summaries = []
for s in samples:
    df = pd.read_csv(r'Z:\School\Sem 4\Research Methodology\Paper\Sample\\' +
                     s)
    df['Date'] = pd.to_datetime(df['Date'])
    # Keep only rows inside the inclusive [start_date, end_date] window.
    df = df[(df['Date'] >= start_date) & (df['Date'] <= end_date)]
    ema_obj = EMA(period_1=13, period_2=34)
    ema_obj.generate(df, calc_returns=True)
    summ = ema_obj.get_summary()
    # Prefix every summary row with the security name; s[:-7] presumably
    # strips a fixed 7-character filename suffix -- TODO confirm.
    summaries.append(
        np.hstack(([[s[:-7]] for _ in range(summ.shape[0])], summ)))

# Stack all per-security summaries into one table with a single concatenation.
# The previous pairwise np.vstack loop re-copied the growing array on every
# iteration (quadratic in the total row count). Like before, an empty
# `summaries` still raises (now ValueError instead of IndexError).
summ = np.vstack(summaries)

ema_returns = pd.DataFrame(
    summ, columns=['Security', 'Return', 'Holding Period', 'Purchase Date'])
ema_returns.to_csv(
    r'Z:\School\Sem 4\Research Methodology\Paper\EMA_13_34_Returns.csv',