Example #1
# NOTE: task_sampler, SerialSampler, and test_one are project-local helpers
# assumed to be importable from the surrounding package.
import datetime

import numpy as np
from termcolor import colored


def test(test_data,
         class_names,
         optG,
         model,
         criterion,
         args,
         test_epoch,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.
    '''
    # model['G'].train()

    acc = []
    for ep in range(test_epoch):

        sampled_classes, source_classes = task_sampler(test_data, args)

        train_gen = SerialSampler(test_data, args, sampled_classes,
                                  source_classes, 1)

        sampled_tasks = train_gen.get_epoch()

        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task,
                             class_names,
                             model,
                             optG,
                             criterion,
                             args,
                             grad={})
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        # The original if/else on args.embedding printed the identical line in
        # both branches, so a single print suffices.
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ),
              flush=True)

    return np.mean(acc), np.std(acc)
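
A minimal call-site sketch for the helper above. `load_dataset` and `build_model` are hypothetical placeholder names, and `args` stands for the project's parsed argument namespace; only the `test(...)` signature comes from the example itself.

# Hypothetical wiring: load_dataset/build_model are illustrative names and
# `args` is assumed to come from the project's argument parser.
import torch
import torch.nn as nn

test_data, class_names = load_dataset(args)      # project-specific loader
model = build_model(args)                        # e.g. {'G': encoder_module}
optG = torch.optim.Adam(model['G'].parameters(), lr=args.meta_lr)
criterion = nn.CrossEntropyLoss()

mean_acc, std_acc = test(test_data, class_names, optG, model,
                         criterion, args, test_epoch=100)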
Example #2
# NOTE: task_sampler, SerialSampler, train_one, test_one, pre_calculate, and
# grad_param are project-local helpers assumed to be importable here.
import copy
import datetime
import os
import time

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm


def train(train_data, val_data, test_data, model, class_names, criterion,
          args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    if args.STS:
        classes_sample_p, example_prob_metrix = pre_calculate(
            train_data, class_names, model['G'], args)
    else:
        classes_sample_p, example_prob_metrix = None, None

    optG = torch.optim.Adam(grad_param(model, ['G']),
                            lr=args.meta_lr,
                            weight_decay=args.weight_decay)
    # optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    # optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        # schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optCLF, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        # schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    # sampled_classes, source_classes = task_sampler(train_data, args)
    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        ep_loss = 0
        for _ in range(args.train_episodes):

            sampled_classes, source_classes = task_sampler(
                train_data, args, classes_sample_p)

            train_gen = SerialSampler(train_data, args, sampled_classes,
                                      source_classes, 1, example_prob_metrix)

            sampled_tasks = train_gen.get_epoch()

            grad = {'clf': [], 'G': []}

            if not args.notqdm:
                sampled_tasks = tqdm(sampled_tasks,
                                     total=train_gen.num_episodes,
                                     ncols=80,
                                     leave=False,
                                     desc=colored('Training on train',
                                                  'yellow'))

            for task in sampled_tasks:
                if task is None:
                    break
                q_loss, q_acc = train_one(task, class_names, model, optG,
                                          criterion, args, grad)
                acc += q_acc
                loss = loss + q_loss
                ep_loss = ep_loss + q_loss

        ep_loss = ep_loss / args.train_episodes

        optG.zero_grad()
        ep_loss.backward()
        optG.step()

        if ep % 100 == 0:
            # Log the last episode's loss/accuracy as a quick progress check.
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", loss:" + str(q_loss.item()) + ", acc:" +
                  str(q_acc.item()) + "-----------")

        test_count = 100
        # if (ep % test_count == 0) and (ep != 0):
        if ep % test_count == 0:
            # Running means over the episodes accumulated since the last report.
            acc = acc / args.train_episodes / test_count
            loss = loss / args.train_episodes / test_count
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) +
                  str(ep) + ", mean_loss:" + str(loss.item()) + ", mean_acc:" +
                  str(acc.item()) + "-----------")

            net = copy.deepcopy(model)
            # (An optional train-set evaluation via test(...) was disabled here.)
            acc = 0
            loss = 0

            # Evaluate test accuracy
            cur_acc, cur_std = test(test_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(
                "[TEST] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}".format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("test  ", "cyan"),
                    colored("acc:", "blue"),
                    cur_acc,
                    cur_std,
                ),
                flush=True)

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(
                "[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}".format(
                    datetime.datetime.now(),
                    "ep",
                    ep,
                    colored("val  ", "cyan"),
                    colored("acc:", "blue"),
                    cur_acc,
                    cur_std,
                ),
                flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))

                torch.save(model['G'].state_dict(), best_path + '.G')
                # torch.save(model['G2'].state_dict(), best_path + '.G2')
                # torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                # schedulerCLF.step(cur_acc)

            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                # schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()),
          flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    # model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    # model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path),
              flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')
        # torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG
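
A hedged sketch of how `train` might be driven. The attribute names on `args` mirror the ones the function reads above, but the values and the `argparse.Namespace` construction are illustrative only; `train_data`, `val_data`, `test_data`, `model`, `class_names`, and `criterion` come from the project's own setup.

# Illustrative configuration -- values are placeholders, not recommendations.
import argparse

args = argparse.Namespace(
    STS=False,                        # skip pre_calculate sampling priors
    meta_lr=1e-3, weight_decay=0.0,   # Adam settings for model['G']
    lr_scheduler='ReduceLROnPlateau', patience=20,
    ExponentialLR_gamma=0.98,
    train_epochs=10000, train_episodes=4,
    test_epochs=100,
    notqdm=True, save=False,
)

optG = train(train_data, val_data, test_data, model, class_names,
             criterion, args)
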
def test(test_data,
         class_names,
         optG,
         model,
         criterion,
         args,
         num_episodes,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.
    '''
    model['G'].train()  # kept in train mode: test_one adapts the model per task

    acc = []
    for ep in range(num_episodes):
        if ep % 100 == 0:
            print(ep)  # lightweight progress indicator
        # (A disabled 'mlada' branch here collected sentence embeddings, word
        # weights, and reconstructions from test_one for later inspection.)
        sampled_classes, source_classes = task_sampler(test_data, args)
        # (A disabled block built a class_names_dict over the sampled classes.)

        train_gen = SerialSampler(test_data, args, sampled_classes,
                                  source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        # class_names_dict = utils.to_tensor(class_names_dict, args.cuda, exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Testing', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names, model, optG, criterion, args,
                             grad)
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        # The original if/else on args.embedding printed the identical line in
        # both branches, so a single print suffices.
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ),
              flush=True)

    return np.mean(acc), np.std(acc)
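
Both `test` variants report the plain sample mean and standard deviation over all per-episode query accuracies; a tiny self-contained check (with made-up accuracy values) shows the aggregation:

# Same aggregation as the end of test(); the accuracies are illustrative.
import numpy as np

acc = np.array([0.81, 0.79, 0.84])
print(np.mean(acc), np.std(acc))   # reported above as "acc: mean ± std"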