Example #1
def _meta_update(model, total_grad, opt, task, maml_batchsize, clip_grad):
    '''
        Aggregate the gradients in total_grad
        Update the initialization in model
    '''

    model['ebd'].train()
    model['clf'].train()
    support, query = task
    XS = model['ebd'](support)
    pred = model['clf'](XS)
    loss = torch.sum(pred)  # dummy loss; its gradient is replaced by the hooks below

    # aggregate the gradients (skip nan)
    avg_grad = {
            'ebd': {key: sum(g[key] for g in total_grad['ebd'] if
                        not torch.sum(torch.isnan(g[key])) > 0)\
                    for key in total_grad['ebd'][0].keys()},
            'clf': {key: sum(g[key] for g in total_grad['clf'] if
                        not torch.sum(torch.isnan(g[key])) > 0)\
                    for key in total_grad['clf'][0].keys()}
            }

    # register a hook on each parameter in the model that replaces
    # the current dummy grad with the meta gradients
    hooks = []
    for model_name in avg_grad.keys():
        for key, value in model[model_name].named_parameters():
            if not value.requires_grad:
                continue

            def get_closure():
                k = key
                n = model_name

                def replace_grad(grad):
                    return avg_grad[n][k] / maml_batchsize

                return replace_grad

            hooks.append(value.register_hook(get_closure()))

    opt.zero_grad()
    loss.backward()

    ebd_grad = get_norm(model['ebd'])
    clf_grad = get_norm(model['clf'])
    if clip_grad is not None:
        nn.utils.clip_grad_value_(grad_param(model, ['ebd', 'clf']), clip_grad)

    opt.step()

    for h in hooks:
        # remove the hooks before the next training phase
        h.remove()

    total_grad['ebd'] = []
    total_grad['clf'] = []

    return ebd_grad, clf_grad
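
These examples lean on a few small helpers that are not shown (grad_param, named_grad_param, get_norm). The following is a minimal sketch of what they are assumed to do, inferred from how they are called; the project's own definitions may differ.

import itertools
import torch

def named_grad_param(model, keys):
    # (name, parameter) pairs that require gradients, across the listed sub-models
    return itertools.chain.from_iterable(
        ((name, p) for name, p in model[key].named_parameters() if p.requires_grad)
        for key in keys)

def grad_param(model, keys):
    # trainable parameters only, across the listed sub-models
    return (p for _, p in named_grad_param(model, keys))

def get_norm(module):
    # total L2 norm of the gradients currently stored in a sub-model
    total = 0.0
    for p in module.parameters():
        if p.grad is not None:
            total += p.grad.norm().item() ** 2
    return total ** 0.5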
Example #2
def test_one(task, fast, args):
    '''
        Evaluate the model on one sampled task. Return the accuracy.
    '''
    support, query = task
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    fast['ebd'].train()
    fast['clf'].train()

    opt = torch.optim.SGD(grad_param(fast, ['ebd', 'clf']),
                          lr=args.maml_stepsize)

    for i in range(args.maml_innersteps):
        XS = fast['ebd'](support)
        pred = fast['clf'](XS)
        loss = F.cross_entropy(pred, YS)

        opt.zero_grad()
        loss.backward()
        opt.step()

    fast['ebd'].eval()
    fast['clf'].eval()

    XQ = fast['ebd'](query)
    pred = fast['clf'](XQ)
    acc = torch.mean((torch.argmax(pred, dim=1) == YQ).float()).item()

    return acc
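
test_one mutates the model it is given during the inner SGD steps, so the caller is assumed to adapt a fresh copy per task. A hedged usage sketch (the evaluate wrapper below is an illustration, not part of the source):

import copy
import numpy as np

def evaluate(model, sampled_tasks, args):
    accs = []
    for task in sampled_tasks:
        # clone the meta-trained initialization before fast adaptation
        fast = {'ebd': copy.deepcopy(model['ebd']),
                'clf': copy.deepcopy(model['clf'])}
        accs.append(test_one(task, fast, args))
    return np.mean(accs), np.std(accs)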
def train_one_fomaml(task, fast, args, total_grad):
    '''
        Update the fast_model based on the support set.
        Return the gradient w.r.t. initializations over the query set
        First order MAML
    '''
    support, query = task

    # map class label into 0,...,num_classes-1
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    opt = torch.optim.SGD(grad_param(fast, ['ebd', 'clf']),
                          lr=args.maml_stepsize)

    fast['ebd'].train()
    fast['clf'].train()

    # fast adaptation
    for i in range(args.maml_innersteps):
        opt.zero_grad()

        XS = fast['ebd'](support)
        acc, loss = fast['clf'](XS, YS)

        loss.backward()

        opt.step()

    # forward on the query, to get meta loss
    XQ = fast['ebd'](query)
    acc, loss = fast['clf'](XQ, YQ)

    loss.backward()

    grads_ebd = {name: p.grad for (name, p) in named_grad_param(fast, ['ebd'])\
                 if p.grad is not None}  # the BERT pooler has no grad
    grads_clf = {name: p.grad for (name, p) in named_grad_param(fast, ['clf'])}

    total_grad['ebd'].append(grads_ebd)
    total_grad['clf'].append(grads_clf)

    return
def train_one(task, model, opt, args, grad):
    '''
        Train the model on one sampled task.
    '''
    model['ebd'].train()
    if args.classifier != 'nn':
        model['clf'].train()
        opt.zero_grad()

    support, query = task

    # Embedding the document
    XS = model['ebd'](support)
    YS = support['label']

    XQ = model['ebd'](query)
    YQ = query['label']

    # Apply the classifier
    _, loss = model['clf'](XS, YS, XQ, YQ)

    print('loss: ', loss)

    if loss is None:
        return

    if torch.isnan(loss):
        # do not update the parameters if the loss is nan
        print("NAN detected")
        print(model['clf'].lam, model['clf'].alpha, model['clf'].beta)
        return

    loss.backward()

    if args.clip_grad is not None:
        nn.utils.clip_grad_value_(grad_param(model, ['ebd', 'clf']),
                                  args.clip_grad)

    if args.classifier != 'nn':
        grad['clf'].append(get_norm(model['clf']))
    grad['ebd'].append(get_norm(model['ebd']))

    if args.classifier != 'nn':
        opt.step()
Example #5
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "tmp-runs",
                                  str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G', 'clf']), lr=args.lr_g)
    optD = torch.optim.Adam(grad_param(model, ['D']), lr=args.lr_d)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerD = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optD, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerD = torch.optim.lr_scheduler.ExponentialLR(
            optD, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(
        datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    for ep in range(args.train_epochs):

        sampled_classes, source_classes = task_sampler(train_data, args)

        train_gen = ParallelSampler(train_data, args, sampled_classes, source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'G': [], 'D': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                    ncols=80, leave=False, desc=colored('Training on train',
                        'yellow'))
        d_acc = 0
        for task in sampled_tasks:
            if task is None:
                break
            d_acc += train_one(task, model, optG, optD, args, grad)

        d_acc = d_acc / args.train_episodes

        print("---------------ep:" + str(ep) + " d_acc:" + str(d_acc) + "-----------")

        if ep % 10 == 0:

            acc, std, _ = test(train_data, model, args, args.val_episodes, False,
                            train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std, _ = test(val_data, model, args, args.val_episodes, False,
                                val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
               datetime.datetime.now(),
               "ep", ep,
               colored("val  ", "cyan"),
               colored("acc:", "blue"), cur_acc, cur_std,
               colored("train stats", "cyan"),
               colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
               colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
               ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now(),
                best_path))

            torch.save(model['G'].state_dict(), best_path + '.G')
            torch.save(model['D'].state_dict(), best_path + '.D')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

        if args.lr_scheduler == 'ReduceLROnPlateau':
            schedulerG.step(cur_acc)
            schedulerD.step(cur_acc)

        elif args.lr_scheduler == 'ExponentialLR':
            schedulerG.step()
            schedulerD.step()

    print("{}, End of training. Restore the best weights".format(
            datetime.datetime.now()),
            flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['D'].load_state_dict(torch.load(best_path + '.D'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
                                      os.path.curdir,
                                      "saved-runs",
                                      str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now(),
            best_path), flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['D'].state_dict(), best_path + '.D')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
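
This adversarial variant expects the model dict to expose a generator/encoder 'G', a discriminator 'D' and a classifier 'clf', each driven by its own optimizer. A hedged sketch of how the dict would be assembled before calling train(); the helper below is illustrative, not the project's API, and the cuda convention is an assumption:

def build_model(ebd_encoder, discriminator, classifier, args):
    # assemble the sub-models expected by train(); the module arguments are
    # placeholders for the project's own networks
    model = {'G': ebd_encoder, 'D': discriminator, 'clf': classifier}
    if args.cuda != -1:
        for m in model.values():
            m.cuda(args.cuda)
    return model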
Example #6
def train(train_data, val_data, model, args):
    '''
        Train the model (obviously~)
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(grad_param(model, ['ebd', 'clf']), lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           'max',
                                                           patience=5,
                                                           factor=0.1,
                                                           verbose=True)

    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    train_gen = ParallelSampler(train_data, args,
                                args.train_episodes * args.maml_batchsize)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)
    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        meta_grad_dict = {'clf': [], 'ebd': []}

        train_episodes = range(args.train_episodes)
        if not args.notqdm:
            train_episodes = tqdm(train_episodes,
                                  ncols=80,
                                  leave=False,
                                  desc=colored('Training on train', 'yellow'))

        for _ in train_episodes:
            # update the initialization based on a batch of tasks
            total_grad = {'ebd': [], 'clf': []}

            for _ in range(args.maml_batchsize):
                task = next(sampled_tasks)

                # clone the current initialization
                _copy_weights(model['ebd'], fast_model['ebd'])
                _copy_weights(model['clf'], fast_model['clf'])

                # get the meta gradient
                train_one(task, fast_model, args, total_grad)

            ebd_grad, clf_grad = _meta_update(model, total_grad, opt, task,
                                              args.maml_batchsize,
                                              args.clip_grad)
            meta_grad_dict['ebd'].append(ebd_grad)
            meta_grad_dict['clf'].append(clf_grad)

        # evaluate training accuracy
        if ep % 10 == 0:
            acc, std = test(train_data, model, args,
                            args.train_episodes * args.maml_batchsize, False,
                            train_gen.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep",
                ep,
                colored("train", "red"),
                colored("acc:", "blue"),
                acc,
                std,
            ),
                  flush=True)

        # evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                   datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                   "ep", ep, colored("val  ", "cyan"), colored("acc:", "blue"),
                   cur_acc, cur_std, colored("train stats", "cyan"),
                   colored("ebd_grad:", "blue"),
                   np.mean(np.array(meta_grad_dict['ebd'])),
                   colored("clf_grad:", "blue"),
                   np.mean(np.array(meta_grad_dict['clf']))),
              flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path),
              flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
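
Example #6 resets the fast model from the current initialization with _copy_weights, which is not shown. A minimal sketch, assuming source and target share the same architecture:

import torch

def _copy_weights(source, target):
    # copy the meta-parameters into the fast model in place, without
    # recording the copy in the autograd graph
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.copy_(src)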
Example #7
def train_one(task, fast, args, total_grad):
    '''
        Update the fast_model based on the support set.
        Return the gradient w.r.t. initializations over the query set
    '''
    support, query = task

    # map class label into 0,...,num_classes-1
    YS, YQ = fast['clf'].reidx_y(support['label'], query['label'])

    fast['ebd'].train()
    fast['clf'].train()

    # get weights
    fast_weights = {
        'ebd':
        OrderedDict((name, param)
                    for (name, param) in named_grad_param(fast, ['ebd'])),
        'clf':
        OrderedDict((name, param)
                    for (name, param) in named_grad_param(fast, ['clf'])),
    }

    num_ebd_w = len(fast_weights['ebd'])
    num_clf_w = len(fast_weights['clf'])

    # fast adaptation
    for i in range(args.maml_innersteps):
        if i == 0:
            XS = fast['ebd'](support)
            pred = fast['clf'](XS)
            loss = F.cross_entropy(pred, YS)
            grads = torch.autograd.grad(loss,
                                        grad_param(fast, ['ebd', 'clf']),
                                        create_graph=True)

        else:
            XS = fast['ebd'](support, fast_weights['ebd'])
            pred = fast['clf'](XS, weights=fast_weights['clf'])
            loss = F.cross_entropy(pred, YS)
            grads = torch.autograd.grad(loss,
                                        itertools.chain(
                                            fast_weights['ebd'].values(),
                                            fast_weights['clf'].values()),
                                        create_graph=True)

        if args.maml_firstorder:
            grads = tuple(g.detach() for g in grads)

        # update fast weight
        fast_weights['ebd'] = OrderedDict(
            (name, param - args.maml_stepsize * grad)
            for ((name, param),
                 grad) in zip(fast_weights['ebd'].items(), grads[:num_ebd_w]))

        fast_weights['clf'] = OrderedDict(
            (name, param - args.maml_stepsize * grad)
            for ((name, param),
                 grad) in zip(fast_weights['clf'].items(), grads[num_ebd_w:]))

    # forward on the query, to get meta loss
    XQ = fast['ebd'](query, fast_weights['ebd'])
    pred = fast['clf'](XQ, weights=fast_weights['clf'])
    loss = F.cross_entropy(pred, YQ)

    grads = torch.autograd.grad(loss, grad_param(fast, ['ebd', 'clf']))

    grads_ebd = {
        name: g
        for ((name, _),
             g) in zip(named_grad_param(fast, ['ebd']), grads[:num_ebd_w])
    }
    grads_clf = {
        name: g
        for ((name, _),
             g) in zip(named_grad_param(fast, ['clf']), grads[num_ebd_w:])
    }

    total_grad['ebd'].append(grads_ebd)
    total_grad['clf'].append(grads_clf)

    return
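
The inner loop above calls the embedding and the classifier with an explicit fast_weights dict, so the modules are assumed to support a functional forward pass. A minimal sketch of a classifier with that interface (the real modules differ):

import torch.nn as nn
import torch.nn.functional as F

class LinearClf(nn.Module):
    def __init__(self, ebd_dim, n_classes):
        super().__init__()
        self.fc = nn.Linear(ebd_dim, n_classes)

    def forward(self, x, weights=None):
        if weights is None:
            # regular forward with the module's own parameters
            return self.fc(x)
        # forward with externally supplied fast weights (keyed like
        # named_parameters), keeping the graph back to the initialization
        # so second-order MAML gradients can flow through the inner updates
        return F.linear(x, weights['fc.weight'], weights['fc.bias'])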
Example #8
def train(train_data, val_data, test_data, model, class_names, criterion,
          args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    if args.STS:
        classes_sample_p, example_prob_metrix = pre_calculate(
            train_data, class_names, model['G'], args)
    else:
        classes_sample_p, example_prob_metrix = None, None

    optG = torch.optim.Adam(grad_param(model, ['G']),
                            lr=args.meta_lr,
                            weight_decay=args.weight_decay)
    # optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    # optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        # schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optCLF, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        # schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        ep_loss = 0
        for _ in range(args.train_episodes):

            sampled_classes, source_classes = task_sampler(
                train_data, args, classes_sample_p)

            train_gen = SerialSampler(train_data, args, sampled_classes,
                                      source_classes, 1, example_prob_metrix)

            sampled_tasks = train_gen.get_epoch()

            grad = {'clf': [], 'G': []}

            if not args.notqdm:
                sampled_tasks = tqdm(sampled_tasks,
                                     total=train_gen.num_episodes,
                                     ncols=80,
                                     leave=False,
                                     desc=colored('Training on train',
                                                  'yellow'))

            for task in sampled_tasks:
                if task is None:
                    break
                q_loss, q_acc = train_one(task, class_names, model, optG,
                                          criterion, args, grad)
                acc += q_acc
                loss = loss + q_loss
                ep_loss = ep_loss + q_loss

        ep_loss = ep_loss / args.train_episodes

        optG.zero_grad()
        ep_loss.backward()
        optG.step()

        if ep % 100 == 0:
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", loss:" + str(q_loss.item()) + ", acc:" +
                  str(q_acc.item()) + "-----------")

        test_count = 100
        # if (ep % test_count == 0) and (ep != 0):
        if (ep % test_count == 0):
            acc = acc / args.train_episodes / test_count
            loss = loss / args.train_episodes / test_count
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", mean_loss:" + str(loss.item()) + ", mean_acc:" +
                  str(acc.item()) + "-----------")

            net = copy.deepcopy(model)
            # acc, std = test(train_data, class_names, optG, net, criterion, args, args.test_epochs, False)
            # print("[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
            #     datetime.datetime.now(),
            #     "ep", ep,
            #     colored("train", "red"),
            #     colored("acc:", "blue"), acc, std,
            #     ), flush=True)
            acc = 0
            loss = 0

            # Evaluate test accuracy
            cur_acc, cur_std = test(test_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(
                ("[TEST] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, ").
                format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("test  ", "cyan"),
                    colored("acc:", "blue"),
                    cur_acc,
                    cur_std,
                    # colored("train stats", "cyan"),
                    # colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
                    # colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
                ),
                flush=True)

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(
                ("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, ").
                format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("val  ", "cyan"),
                    colored("acc:", "blue"),
                    cur_acc,
                    cur_std,
                    # colored("train stats", "cyan"),
                    # colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
                    # colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
                ),
                flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                # torch.save(model['G2'].state_dict(), best_path + '.G2')
                # torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                # schedulerCLF.step(cur_acc)

            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                # schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()),
          flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    # model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    # model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path),
              flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        # torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG
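
The training loops above read many hyperparameters off args. A hedged sketch of the corresponding argparse-style namespace; the attribute names are taken from the code above, but the values are illustrative only and not from the source:

from argparse import Namespace

args = Namespace(
    # episode / epoch budget
    train_epochs=1000, train_episodes=100, val_episodes=100, test_epochs=100,
    # optimization
    meta_lr=1e-3, task_lr=1e-2, weight_decay=0.0, clip_grad=None,
    lr_scheduler='ReduceLROnPlateau', ExponentialLR_gamma=0.98,
    # early stopping and bookkeeping
    patience=20, notqdm=False, save=False, cuda=0, STS=False,
)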
Example #9
def train(train_data, val_data, model, class_names, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr)
    optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optCLF,
            'max',
            patience=args.patience // 2,
            factor=0.1,
            verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
            optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # train_gen = ParallelSampler(train_data, args, args.train_episodes)
    # train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    # val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):

        sampled_classes, source_classes = task_sampler(train_data, args)
        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict,
                                           args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_loss, q_acc = train_one(task, class_names_dict, model, optG,
                                      optG2, optCLF, args, grad)
            acc += q_acc
            loss += q_loss

        if ep % 100 == 0:
            print("--------[TRAIN] ep:" + str(ep) + ", loss:" +
                  str(q_loss.item()) + ", acc:" + str(q_acc.item()) +
                  "-----------")

        if (ep % 200 == 0) and (ep != 0):
            acc = acc / args.train_episodes / 200
            loss = loss / args.train_episodes / 200
            print("--------[TRAIN] ep:" + str(ep) + ", mean_loss:" +
                  str(loss.item()) + ", mean_acc:" + str(acc.item()) +
                  "-----------")

            net = copy.deepcopy(model)
            acc, std = test(train_data, class_names, optG, optCLF, net, args,
                            args.test_epochs, False)
            print(
                "[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("train", "red"),
                    colored("acc:", "blue"),
                    acc,
                    std,
                ),
                flush=True)
            acc = 0
            loss = 0

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, optCLF, net,
                                    args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                       datetime.datetime.now(),
                       "ep",
                       ep,
                       colored("val  ", "cyan"),
                       colored("acc:", "blue"),
                       cur_acc,
                       cur_std,
                       colored("train stats", "cyan"),
                       colored("G_grad:", "blue"),
                       np.mean(np.array(grad['G'])),
                       colored("clf_grad:", "blue"),
                       np.mean(np.array(grad['clf'])),
                   ),
                  flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                torch.save(model['G2'].state_dict(), best_path + '.G2')
                torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                schedulerCLF.step(cur_acc)

            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()),
          flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path),
              flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG, optCLF
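
class_names_dict is moved to the GPU with utils.to_tensor, which is not shown. A minimal sketch of what it is assumed to do, based on how it is called above; the real utility may handle more cases:

import torch

def to_tensor(data, cuda, exclude_keys=()):
    # move every tensor entry of a task/class dict onto the given CUDA device,
    # leaving the excluded (non-tensor) keys untouched
    for key in data:
        if key in exclude_keys:
            continue
        if torch.is_tensor(data[key]):
            data[key] = data[key].cuda(cuda) if cuda != -1 else data[key]
    return data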
Example #10
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "tmp-runs",
                                  str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(grad_param(model, ['clf']), lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, 'max', patience=args.patience//2, factor=0.1, verbose=True)

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')), flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                    ncols=80, leave=False, desc=colored('Training on train',
                        'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        acc, std = test(train_data, model, args, args.val_episodes, False,
                        train_gen_val.get_epoch())
        print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "ep", ep,
            colored("train", "red"),
            colored("acc:", "blue"), acc, std,
            ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes, False,
                                val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
               datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
               "ep", ep,
               colored("val  ", "cyan"),
               colored("acc:", "blue"), cur_acc, cur_std,
               colored("train stats", "cyan"),
               colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
               colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
               ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past 20 epochs
        if sub_cycle == 20:
            break

    print("{}, End of training. Restore the best weights".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
            flush=True)

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    # save the current model
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "saved-runs",
                                  str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_path = os.path.join(out_dir, 'best')

    print("{}, Save best model to {}".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
        best_path), flush=True)

    torch.save(model['ebd'].state_dict(), best_path + '.ebd')
    torch.save(model['clf'].state_dict(), best_path + '.clf')

    return
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping

        Args:
            model (dict): {'ebd': embedding, 'clf': classifier}
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "tmp-runs",
                                  str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Write results
    # write_acc_tr = 'acc_base.csv'
    # init_csv(write_acc_tr)
    # write_acc_val = 'val_acc_base.csv'
    # init_csv(write_acc_val)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    # grad_param generates the learnable parameters from the classifier
    params_to_opt = grad_param(model, ['ebd', 'clf'])
    opt = torch.optim.Adam(params_to_opt, lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, 'max', patience=args.patience//2, factor=0.1, verbose=True)

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')), flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                    ncols=80, leave=False, desc=colored('Training on train',
                        'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        if ep % 10 == 0:
            acc, std = test(train_data, model, args, args.val_episodes, False,
                            train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)

            # write_csv(write_acc_tr, acc, std, ep)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes, False,
                                val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
               datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
               "ep", ep,
               colored("val  ", "cyan"),
               colored("acc:", "blue"), cur_acc, cur_std,
               colored("train stats", "cyan"),
               colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
               colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
               ), flush=True)

        # if ep % 10 == 0: write_csv(write_acc_val, cur_acc, cur_std, ep)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
            flush=True)

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
                                      os.path.curdir,
                                      "saved-runs",
                                      str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path), flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
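
Finally, a hedged end-to-end sketch of how the embedding + classifier variant above is assumed to be driven. The run wrapper and its arguments are illustrative only; only the model dict layout and the train() call follow the code above:

def run(ebd, clf, train_data, val_data, args):
    # `ebd` and `clf` are the embedding and classifier nn.Modules;
    # the dict layout below is what train() above expects
    model = {'ebd': ebd, 'clf': clf}
    if args.cuda != -1:
        model['ebd'].cuda(args.cuda)
        model['clf'].cuda(args.cuda)
    train(train_data, val_data, model, args)  # early-stops on validation acc
    return model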