# Example 1
def main():
    """Evaluate a feature extractor + classifier pair on a test set.

    Reads all settings from the module-level ``opt`` namespace (presumably
    an argparse result -- TODO confirm): ``test_dir``, ``feat``, ``cls``,
    ``b`` (batch size) and ``i`` (backbone name).
    """
    test_dir = opt.test_dir
    feature_param_file = opt.feat
    class_param_file = opt.cls
    bsize = opt.b

    # Backbone selection. The original if/elif chain had no else branch, so
    # an unknown backbone name crashed later with a NameError on `feature`;
    # fail fast with a clear message instead.
    if 'vgg' == opt.i:
        feature = Vgg16()
    elif 'resnet' == opt.i:
        feature = resnet50()
    elif 'densenet' == opt.i:
        feature = densenet121()
    else:
        raise ValueError('unknown backbone: {}'.format(opt.i))
    feature.cuda()
    # feature.load_state_dict(torch.load(feature_param_file))
    feature.eval()

    classifier = Classifier(opt.i)
    classifier.cuda()
    # classifier.load_state_dict(torch.load(class_param_file))
    classifier.eval()

    # NOTE(review): shuffle=True is harmless for accuracy (which is
    # order-independent) but unnecessary for evaluation.
    loader = torch.utils.data.DataLoader(MyClsTestData(test_dir,
                                                       transform=True),
                                         batch_size=bsize,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True)
    acc = eval_acc(feature, classifier, loader)
    # print() works on both Python 2 and 3; the original bare
    # ``print acc`` statement is a SyntaxError under Python 3.
    print(acc)
def evaluate(checkpoint_path,
             num_class,
             num_words,
             datafolds,
             glove=None,
             use_gpu=False):
    """Restore a Classifier from *checkpoint_path* and score it on the
    last fold of *datafolds*.

    Returns the (correct, total) pair produced by the classifier's own
    ``evalute_dataset`` helper (sic -- the method name comes from the
    Classifier class itself).
    """
    state = torch.load(checkpoint_path)
    # The training run encodes its layer sizes in the config string,
    # e.g. "input300_hidden128"; recover them to rebuild the model.
    match = re.search(r'input(\d+)_hidden(\d+)', state['config_string'])
    input_size = int(match.group(1))
    hidden_size = int(match.group(2))
    model = Classifier(input_size, hidden_size, num_class, num_words,
                       glove, use_gpu)
    if use_gpu:
        model = model.cuda()
    model.load_state_dict(state['model'])
    correct, total = model.evalute_dataset(datafolds[-1])
    return correct, total
# Example 3
def main():
    """Train the top-tagging classifier end to end.

    Parses CLI options, prepares output directories and logging, builds
    the data loaders / model / optimizer, runs the train-validate loop,
    then reloads the best checkpoints for final evaluation.
    """

    def _str2bool(text):
        # argparse's ``type=bool`` is a well-known trap: bool("False") is
        # True because every non-empty string is truthy, so ``--verbose
        # False`` silently enabled verbosity. Parse the text explicitly.
        return str(text).lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--train-set',
        type=str,
        dest="train_set",
        default="/store/slowmoyang/TopTagging/toptagging-training.root")
    parser.add_argument(
        '--valid-set',
        type=str,
        dest="valid_set",
        default="/store/slowmoyang/TopTagging/toptagging-validation.root")
    parser.add_argument(
        '--test-set',
        dest="test_set",
        default="/store/slowmoyang/TopTagging/toptagging-test.root",
        type=str)
    parser.add_argument('--batch-size',
                        dest="batch_size",
                        default=128,
                        type=int,
                        help='batch size')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=2048,
                        dest="test_batch_size",
                        help='batch size for test and validation')
    parser.add_argument('--epoch',
                        dest="num_epochs",
                        default=2,
                        type=int,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        default=0.005,
                        type=float,
                        help='learning rate, default=0.005')
    parser.add_argument("--logdir",
                        dest="log_dir",
                        default="./logs/untitled",
                        type=str,
                        help="the path to directory")
    parser.add_argument("--verbose", default=True, type=_str2bool)
    args = parser.parse_args()

    #####################################
    # Output directories
    ######################################
    log_dir = Directory(args.log_dir, create=True)
    log_dir.mkdir("state_dict")
    log_dir.mkdir("validation")
    log_dir.mkdir("roc_curve")

    ##################################################
    # Logger (console + file, same format)
    ################################################
    logger = logging.getLogger("TopTagging")
    logger.setLevel(logging.INFO)

    format_str = '[%(asctime)s] %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    formatter = logging.Formatter(format_str, date_format)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    log_file_path = log_dir.concat("log.txt")
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    ###############################
    # Device
    ####################################
    # NOTE(review): hard-coded to the first GPU; no CPU fallback.
    device = torch.device("cuda:0")

    ######################################
    # Data Loader
    ########################################
    train_loader = get_data_loader(path=args.train_set,
                                   batch_size=args.batch_size)

    valid_loader = get_data_loader(path=args.valid_set,
                                   batch_size=args.test_batch_size)

    test_loader = get_data_loader(path=args.test_set,
                                  batch_size=args.test_batch_size)

    #####################
    # Model
    ######################
    model = Classifier()
    model.apply(init_weights)
    model.cuda(device)

    if args.verbose:
        logger.info(model)

    ##################################
    # Objective, optimizer
    ##################################
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    ################################
    # Callbacks
    ################################
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=args.verbose)

    ########################################
    # Training loop
    #######################################
    # ``range`` instead of the Python-2-only ``xrange``; 1-based so log
    # messages read "Epoch [1/N]".
    for epoch in range(1, args.num_epochs + 1):
        logger.info("Epoch: [{:d}/{:d}]".format(epoch, args.num_epochs))

        train(model=model,
              data_loader=train_loader,
              optimizer=optimizer,
              criterion=criterion,
              device=device,
              logger=logger)

        results = validate(model, valid_loader, device, logger)

        # Callbacks: anneal the LR on the validation loss plateau.
        scheduler.step(results["loss"])

        save_model(model, log_dir, epoch, results)

    # Reload the checkpoints judged good and (eventually) evaluate them.
    good_states = find_good_state(log_dir.state_dict.path)
    for each in good_states:
        model.load_state_dict(torch.load(each))
        # evaluate(model, test_loader, log_dir)

    logger.info("END")
# Example 4
class ModelBuilder(object):
    """Adversarial trainer for a pair of discourse-argument encoders.

    Components (built in ``_build_model``): two CNN argument encoders
    (``i_encoder`` and ``a_encoder`` -- presumably implicit vs. explicit
    arguments, TODO confirm), a shared sense ``Classifier`` and a
    ``Discriminator`` trained GAN-style to tell the two encoders'
    representations apart while ``i_encoder`` learns to fool it
    (``_adtrain_one``).

    NOTE(review): this code targets a pre-0.4 PyTorch API -- ``Variable``,
    ``volatile=True``, indexing 0-dim ``.data`` tensors with ``[0]`` and
    the deprecated ``torch.nn.utils.clip_grad_norm`` -- it will not run
    unmodified on modern PyTorch.
    """

    def __init__(self, use_cuda):
        """Load data, build all modules/optimizers, reset the batch counter."""
        self.cuda = use_cuda
        self._pre_data()
        self._build_model()
        self.i_mb = 0  # global mini-batch counter, used by _adtrain_one logging

    def _pre_data(self):
        """Create the Data wrapper exposing train/dev/test loaders and sizes."""
        print('pre data...')
        self.data = Data(self.cuda)

    def _build_model(self):
        """Build encoders, classifier, discriminator, losses and optimizers."""
        print('building model...')
        we = torch.load('./data/processed/we.pkl')  # pre-trained word embeddings
        self.i_encoder = CNN_Args_encoder(we)
        self.a_encoder = CNN_Args_encoder(we, need_kmaxavg=True)
        self.classifier = Classifier()
        self.discriminator = Discriminator()
        if self.cuda:
            self.i_encoder.cuda()
            self.a_encoder.cuda()
            self.classifier.cuda()
            self.discriminator.cuda()
        self.criterion_c = torch.nn.CrossEntropyLoss()  # sense classification
        self.criterion_d = torch.nn.BCELoss()  # discriminator real/fake loss
        # Optimize only parameters that require gradients (filters out any
        # frozen parameters -- presumably the embeddings, TODO confirm).
        para_filter = lambda model: filter(lambda p: p.requires_grad,
                                           model.parameters())
        self.i_optimizer = torch.optim.Adagrad(para_filter(self.i_encoder),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        self.a_optimizer = torch.optim.Adagrad(para_filter(self.a_encoder),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        self.c_optimizer = torch.optim.Adagrad(self.classifier.parameters(),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        # The discriminator alone uses Adam with its own learning rate.
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                            Config.lr_d,
                                            weight_decay=Config.l2_penalty)

    def _print_train(self, epoch, time, loss, acc):
        """Print a one-line epoch summary framed by horizontal rules.

        The ``time`` parameter shadows the ``time`` module inside this
        method; it is the elapsed seconds for the epoch.
        """
        print('-' * 80)
        print(
            '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |'
            .format(epoch, time, loss, acc * 100))
        print('-' * 80)

    def _print_eval(self, task, loss, acc):
        """Print a one-line evaluation summary for *task* ('dev'/'test')."""
        print('| ' + task +
              ' loss {:10.5f} | acc {:5.2f}% |'.format(loss, acc * 100))
        print('-' * 80)

    def _save_model(self, model, filename):
        """Serialize *model*'s parameters to ./weights/<filename>."""
        torch.save(model.state_dict(), './weights/' + filename)

    def _load_model(self, model, filename):
        """Restore *model*'s parameters from ./weights/<filename>."""
        model.load_state_dict(torch.load('./weights/' + filename))

    def _pretrain_i_one(self):
        """Run one supervised pre-training epoch for i_encoder + classifier.

        Returns (mean training loss, training accuracy).
        """
        self.i_encoder.train()
        self.classifier.train()
        total_loss = 0
        correct_n = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            output = self.classifier(self.i_encoder(a1, a2i))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            # Weight the (mean) batch loss by batch size for a dataset mean.
            total_loss += loss.data * sense.size(0)
        # In pre-0.4 PyTorch ``.data`` of a scalar is a 1-element tensor,
        # hence the [0] indexing.
        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size

    def _pretrain_i_a_one(self):
        """Run one epoch jointly pre-training both encoders with the shared
        classifier (each encoder gets its own backward/step).

        Returns (loss_i, acc_i, loss_a, acc_a) averaged over the train set.
        """
        self.i_encoder.train()
        self.a_encoder.train()
        self.classifier.train()
        total_loss = 0
        correct_n = 0
        total_loss_a = 0
        correct_n_a = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            # train i
            output = self.classifier(self.i_encoder(a1, a2i))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            total_loss += loss.data * sense.size(0)

            #train a
            output = self.classifier(self.a_encoder(a1, a2a))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n_a += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.a_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.a_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.a_optimizer.step()
            self.c_optimizer.step()

            total_loss_a += loss.data * sense.size(0)
        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size, total_loss_a[
                0] / self.data.train_size, correct_n_a[0] / self.data.train_size

    def _adtrain_one(self, acc_d_for_train):
        """Run one adversarial training epoch.

        Phase 1 (per batch, up to Config.kd times): train the discriminator
        to output low scores for i_encoder representations and high scores
        for a_encoder ones, with noisy (label-smoothed) targets.
        Phase 2: train i_encoder + classifier both to classify the sense and
        to fool the (frozen) discriminator.

        *acc_d_for_train* is the previous epoch's discriminator-accuracy
        estimate, used as a floor for the update-gating threshold.

        Returns (loss, acc, acc_d, loss_2, acc_d_for_train) over the epoch.
        """
        total_loss = 0
        total_loss_2 = 0
        correct_n = 0
        correct_n_d = 0
        correct_n_d_for_train = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            # phase 1, train discriminator
            flag = 0  # NOTE(review): never used afterwards
            for k in range(Config.kd):
                # if self._test_d() != 1:
                if True:  # leftover gating condition, kept always-true
                    temp_d = 0  # counts the discriminator's "correct" outputs
                    self.a_encoder.eval()
                    self.i_encoder.eval()
                    self.discriminator.train()
                    self.d_optimizer.zero_grad()
                    output_i = self.discriminator(self.i_encoder(a1, a2i))
                    temp_d += torch.sum((output_i < 0.5).long()).data
                    # zero_tensor = torch.zeros(output_i.size())
                    # Noisy near-zero targets (random ints 0..99 scaled to
                    # [0, 0.297]) instead of hard 0s -- label smoothing.
                    zero_tensor = torch.Tensor(output_i.size()).random_(
                        0, 100) * 0.003
                    if self.cuda:
                        zero_tensor = zero_tensor.cuda()
                    zero_tensor = Variable(zero_tensor)
                    d_loss_i = self.criterion_d(output_i, zero_tensor)
                    d_loss_i.backward()
                    output_a = self.discriminator(self.a_encoder(a1, a2a))
                    temp_d += torch.sum((output_a >= 0.5).long()).data
                    # one_tensor = torch.ones(output_a.size())
                    # one_tensor = torch.Tensor(output_a.size()).fill_(Config.alpha)
                    # Noisy near-one targets in [0.7, 1.195) -- NOTE(review):
                    # values above 1 are unusual targets for BCELoss.
                    one_tensor = torch.Tensor(output_a.size()).random_(
                        0, 100) * 0.005 + 0.7
                    if self.cuda:
                        one_tensor = one_tensor.cuda()
                    one_tensor = Variable(one_tensor)
                    d_loss_a = self.criterion_d(output_a, one_tensor)
                    d_loss_a.backward()
                    correct_n_d_for_train += temp_d
                    # Estimated discriminator accuracy on this batch, floored
                    # by last epoch's estimate.
                    temp_d = max(temp_d[0] / sense.size(0) / 2,
                                 acc_d_for_train)
                    # Only update D while it is below the accuracy ceiling,
                    # so it never gets too strong for the encoder.
                    if temp_d < Config.thresh_high:
                        torch.nn.utils.clip_grad_norm(
                            self.discriminator.parameters(), Config.grad_clip)
                        self.d_optimizer.step()

            # phase 2, train i/c
            self.i_encoder.train()
            self.classifier.train()
            self.discriminator.eval()
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            sent_repr = self.i_encoder(a1, a2i)

            output = self.classifier(sent_repr)
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data
            loss_1 = self.criterion_c(output, sense)

            # Adversarial term: push D's output on i-representations
            # towards 1 ("looks like an a-representation").
            output_d = self.discriminator(sent_repr)
            correct_n_d += torch.sum((output_d < 0.5).long()).data
            one_tensor = torch.ones(output_d.size())
            # one_tensor = torch.Tensor(output_d.size()).fill_(Config.alpha)
            # one_tensor = torch.Tensor(output_d.size()).random_(0,100) * 0.005 + 0.7
            if self.cuda:
                one_tensor = one_tensor.cuda()
            one_tensor = Variable(one_tensor)
            loss_2 = self.criterion_d(output_d, one_tensor)

            loss = loss_1 + loss_2 * Config.lambda1
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            total_loss += loss.data * sense.size(0)
            total_loss_2 += loss_2.data * sense.size(0)

            # NOTE(review): runs a full test-set evaluation after EVERY
            # mini-batch for logging -- expensive; requires self.logwriter
            # (created in train()).
            test_loss, test_acc = self._eval('test', 'i')
            self.logwriter.add_scalar('acc/test_acc_t_mb', test_acc * 100,
                                      self.i_mb)
            self.i_mb += 1

        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size, correct_n_d[
                0] / self.data.train_size, total_loss_2[
                    0] / self.data.train_size, correct_n_d_for_train[
                        0] / self.data.train_size / 2

    def _pretrain_i(self):
        """Pre-train i_encoder + classifier, checkpointing the best epoch.

        NOTE(review): model selection uses TEST accuracy, not dev accuracy.
        Requires self.logwriter (created in train()).
        """
        best_test_acc = 0
        for epoch in range(Config.pre_i_epochs):
            start_time = time.time()
            loss, acc = self._pretrain_i_one()
            self._print_train(epoch, time.time() - start_time, loss, acc)
            self.logwriter.add_scalar('loss/train_loss_i', loss, epoch)
            self.logwriter.add_scalar('acc/train_acc_i', acc * 100, epoch)

            dev_loss, dev_acc = self._eval('dev', 'i')
            self._print_eval('dev', dev_loss, dev_acc)
            self.logwriter.add_scalar('loss/dev_loss_i', dev_loss, epoch)
            self.logwriter.add_scalar('acc/dev_acc_i', dev_acc * 100, epoch)

            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
            self.logwriter.add_scalar('loss/test_loss_i', test_loss, epoch)
            self.logwriter.add_scalar('acc/test_acc_i', test_acc * 100, epoch)
            if test_acc >= best_test_acc:
                best_test_acc = test_acc
                self._save_model(self.i_encoder, 'i.pkl')
                self._save_model(self.classifier, 'c.pkl')
                print('i_model saved at epoch {}'.format(epoch))

    def _adjust_learning_rate(self, optimizer, lr):
        """Set every param group of *optimizer* to learning rate *lr*."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def _train_together(self):
        """Two-stage schedule: joint i/a pre-training for the first
        Config.first_stage_epochs epochs, then adversarial training (with a
        one-time learning-rate switch to Config.lr_t). Checkpoints the best
        test accuracy. Requires self.logwriter (created in train()).
        """
        best_test_acc = 0
        loss = acc = loss_a = acc_a = 0
        lr_t = Config.lr_t
        acc_d_for_train = 0
        for epoch in range(Config.together_epochs):
            start_time = time.time()
            if epoch < Config.first_stage_epochs:
                loss, acc, loss_a, acc_a = self._pretrain_i_a_one()
            else:
                if epoch == Config.first_stage_epochs:
                    # One-time LR change when entering the adversarial stage.
                    self._adjust_learning_rate(self.i_optimizer, lr_t)
                    self._adjust_learning_rate(self.c_optimizer, lr_t / 2)
                # elif (epoch - Config.first_stage_epochs) % 20 == 0:
                #     lr_t *= 0.8
                #     self._adjust_learning_rate(self.i_optimizer, lr_t)
                #     self._adjust_learning_rate(self.c_optimizer, lr_t)
                loss, acc, acc_d, loss_2, acc_d_for_train = self._adtrain_one(
                    acc_d_for_train)
            self._print_train(epoch, time.time() - start_time, loss, acc)
            self.logwriter.add_scalar('loss/train_loss_t', loss, epoch)
            self.logwriter.add_scalar('acc/train_acc_t', acc * 100, epoch)
            self.logwriter.add_scalar('loss/train_loss_t_a', loss_a, epoch)
            self.logwriter.add_scalar('acc/train_acc_t_a', acc_a * 100, epoch)
            if epoch >= Config.first_stage_epochs:
                self.logwriter.add_scalar('acc/train_acc_d', acc_d * 100,
                                          epoch)
                self.logwriter.add_scalar('loss/train_loss_2', loss_2, epoch)
                self.logwriter.add_scalar('acc/acc_d_for_train',
                                          acc_d_for_train * 100, epoch)

            dev_loss, dev_acc = self._eval('dev', 'i')
            dev_loss_a, dev_acc_a = self._eval('dev', 'a')
            self._print_eval('dev', dev_loss, dev_acc)
            self.logwriter.add_scalar('loss/dev_loss_t', dev_loss, epoch)
            self.logwriter.add_scalar('acc/dev_acc_t', dev_acc * 100, epoch)
            self.logwriter.add_scalar('loss/dev_loss_t_a', dev_loss_a, epoch)
            self.logwriter.add_scalar('acc/dev_acc_t_a', dev_acc_a * 100,
                                      epoch)
            if epoch >= Config.first_stage_epochs:
                dev_acc_d = self._eval_d('dev')
                self.logwriter.add_scalar('acc/dev_acc_d', dev_acc_d * 100,
                                          epoch)

            test_loss, test_acc = self._eval('test', 'i')
            test_loss_a, test_acc_a = self._eval('test', 'a')
            self._print_eval('test', test_loss, test_acc)
            self.logwriter.add_scalar('loss/test_loss_t', test_loss, epoch)
            self.logwriter.add_scalar('acc/test_acc_t', test_acc * 100, epoch)
            self.logwriter.add_scalar('loss/test_loss_t_a', test_loss_a, epoch)
            self.logwriter.add_scalar('acc/test_acc_t_a', test_acc_a * 100,
                                      epoch)
            if epoch >= Config.first_stage_epochs:
                test_acc_d = self._eval_d('test')
                self.logwriter.add_scalar('acc/test_acc_d', test_acc_d * 100,
                                          epoch)
            # NOTE(review): model selection on test accuracy again.
            if test_acc >= best_test_acc:
                best_test_acc = test_acc
                self._save_model(self.i_encoder, 't_i.pkl')
                self._save_model(self.classifier, 't_c.pkl')
                print('t_i t_c saved at epoch {}'.format(epoch))

    def train(self, i_or_t):
        """Entry point: 'i' runs supervised pre-training, 't' runs the
        two-stage adversarial schedule. Creates the TensorBoard writer the
        sub-routines log to.
        """
        print('start training')
        self.logwriter = SummaryWriter(Config.logdir)
        if i_or_t == 'i':
            self._pretrain_i()
        elif i_or_t == 't':
            self._train_together()
        else:
            raise Exception('wrong i_or_t')
        print('training done')

    def _eval(self, task, i_or_a):
        """Evaluate classifier accuracy/loss on 'dev' or 'test' using the
        encoder selected by *i_or_a* ('i' or 'a').

        Eval batches carry two annotated senses; a prediction counts as
        correct if it matches either one. Returns (mean loss, accuracy).
        """
        self.i_encoder.eval()
        self.a_encoder.eval()
        self.classifier.eval()
        total_loss = 0
        correct_n = 0
        if task == 'dev':
            data = self.data.dev_loader
            n = self.data.dev_size
        elif task == 'test':
            data = self.data.test_loader
            n = self.data.test_size
        else:
            raise Exception('wrong eval task')
        for a1, a2i, a2a, sense1, sense2 in data:
            if self.cuda:
                a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense1.cuda(), sense2.cuda()
            # volatile=True is the pre-0.4 way of disabling autograd.
            a1 = Variable(a1, volatile=True)
            a2i = Variable(a2i, volatile=True)
            a2a = Variable(a2a, volatile=True)
            sense1 = Variable(sense1, volatile=True)
            sense2 = Variable(sense2, volatile=True)

            if i_or_a == 'i':
                output = self.classifier(self.i_encoder(a1, a2i))
            elif i_or_a == 'a':
                output = self.classifier(self.a_encoder(a1, a2a))
            else:
                raise Exception('wrong i_or_a')
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense1.size()
            # Where the prediction matches the second sense, treat that as
            # the gold label (in-place update of sense1 via gold_sense).
            gold_sense = sense1
            mask = (output_sense == sense2)
            gold_sense[mask] = sense2[mask]
            tmp = (output_sense == gold_sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, gold_sense)
            total_loss += loss.data * gold_sense.size(0)
        return total_loss[0] / n, correct_n[0] / n

    def _eval_d(self, task):
        """Estimate discriminator accuracy on i-representations for *task*
        ('train', 'dev' or 'test'): fraction of i_encoder outputs scored
        below 0.5. The a-encoder half of the check is commented out.
        """
        self.i_encoder.eval()
        self.a_encoder.eval()
        self.classifier.eval()
        correct_n = 0
        if task == 'train':
            n = self.data.train_size
            for a1, a2i, a2a, sense in self.data.train_loader:
                if self.cuda:
                    a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                    ), sense.cuda()
                a1 = Variable(a1, volatile=True)
                a2i = Variable(a2i, volatile=True)
                a2a = Variable(a2a, volatile=True)
                sense = Variable(sense, volatile=True)

                output_i = self.discriminator(self.i_encoder(a1, a2i))
                correct_n += torch.sum((output_i < 0.5).long()).data
                # output_a = self.discriminator(self.a_encoder(a1, a2a))
                # correct_n += torch.sum((output_a >= 0.5).long()).data
        else:
            # NOTE(review): any task other than 'dev'/'test' here would leave
            # ``data``/``n`` unbound and raise a NameError.
            if task == 'dev':
                data = self.data.dev_loader
                n = self.data.dev_size
            elif task == 'test':
                data = self.data.test_loader
                n = self.data.test_size
            for a1, a2i, a2a, sense1, sense2 in data:
                if self.cuda:
                    a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda(
                    ), a2a.cuda(), sense1.cuda(), sense2.cuda()
                a1 = Variable(a1, volatile=True)
                a2i = Variable(a2i, volatile=True)
                a2a = Variable(a2a, volatile=True)
                sense1 = Variable(sense1, volatile=True)
                sense2 = Variable(sense2, volatile=True)

                output_i = self.discriminator(self.i_encoder(a1, a2i))
                correct_n += torch.sum((output_i < 0.5).long()).data
                # output_a = self.discriminator(self.a_encoder(a1, a2a))
                # correct_n += torch.sum((output_a >= 0.5).long()).data
        return correct_n[0] / n

    def _test_d(self):
        """Bucket dev discriminator accuracy into a phase flag:
        1 (>= thresh_high), 0 (between thresholds) or -1 (<= thresh_low).
        """
        acc = self._eval_d('dev')
        phase = -100  # sentinel; always overwritten below
        if acc >= Config.thresh_high:
            phase = 1
        elif acc > Config.thresh_low:
            phase = 0
        else:
            phase = -1
        return phase

    def eval(self, stage):
        """Reload the checkpoints for *stage* ('i' pre-trained or 't'
        adversarially trained) and report test loss/accuracy.
        """
        if stage == 'i':
            self._load_model(self.i_encoder, 'i.pkl')
            self._load_model(self.classifier, 'c.pkl')
            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
        elif stage == 't':
            self._load_model(self.i_encoder, 't_i.pkl')
            self._load_model(self.classifier, 't_c.pkl')
            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
        else:
            raise Exception('wrong eval stage')
class ModelBuilder(object):
    def __init__(self, use_cuda, conf):
        """Store the device flag and config, then load data and build the model."""
        self.cuda = use_cuda
        self.conf = conf
        self._pre_data()
        self._build_model()

    def _pre_data(self):
        """Create the Data wrapper (loaders and split sizes) for this config."""
        print('pre data...')
        self.data = Data(self.cuda, self.conf)

    def _build_model(self):
        """Load embeddings, build the encoder/classifier(s) and optimizers.

        The preprocessed-data directory depends on ``conf.corpus_splitting``
        (1 = lin, 2 = ji, 3 = l).

        Raises:
            ValueError: if ``conf.corpus_splitting`` is not 1, 2 or 3.
        """
        print('loading embedding...')
        if self.conf.corpus_splitting == 1:
            pre = './data/processed/lin/'
        elif self.conf.corpus_splitting == 2:
            pre = './data/processed/ji/'
        elif self.conf.corpus_splitting == 3:
            pre = './data/processed/l/'
        else:
            # Previously an unknown value fell through and crashed later
            # with a NameError on ``pre``; fail fast with a clear message.
            raise ValueError('unknown corpus_splitting: {}'.format(
                self.conf.corpus_splitting))
        we = torch.load(pre + 'we.pkl')  # pre-trained word embeddings
        char_table = None
        sub_table = None
        if self.conf.need_char or self.conf.need_elmo:
            char_table = torch.load(pre + 'char_table.pkl')
        if self.conf.need_sub:
            sub_table = torch.load(pre + 'sub_table.pkl')
        print('building model...')
        self.encoder = ArgEncoder(self.conf, we, char_table, sub_table,
                                  self.cuda)
        self.classifier = Classifier(self.conf.clf_class_num, self.conf)
        if self.conf.is_mttrain:
            # Auxiliary connective classifier for multi-task training.
            self.conn_classifier = Classifier(self.conf.conn_num, self.conf)
        if self.cuda:
            self.encoder.cuda()
            self.classifier.cuda()
            if self.conf.is_mttrain:
                self.conn_classifier.cuda()
        self.criterion = torch.nn.CrossEntropyLoss()
        # Optimize only parameters that require gradients.
        para_filter = lambda model: filter(lambda p: p.requires_grad,
                                           model.parameters())
        self.e_optimizer = torch.optim.Adagrad(
            para_filter(self.encoder),
            self.conf.lr,
            weight_decay=self.conf.l2_penalty)
        self.c_optimizer = torch.optim.Adagrad(
            para_filter(self.classifier),
            self.conf.lr,
            weight_decay=self.conf.l2_penalty)
        if self.conf.is_mttrain:
            self.con_optimizer = torch.optim.Adagrad(
                para_filter(self.conn_classifier),
                self.conf.lr,
                weight_decay=self.conf.l2_penalty)

    def _print_train(self, epoch, time, loss, acc):
        """Print a one-line epoch summary framed by horizontal rules.

        ``time`` (which shadows the module of the same name here) is the
        elapsed seconds; ``acc`` is a fraction, printed as a percentage.
        """
        rule = '-' * 80
        summary = (
            '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |'
            .format(epoch, time, loss, acc * 100))
        print(rule)
        print(summary)
        print(rule)

    def _print_eval(self, task, loss, acc, f1):
        """Print one evaluation summary line for *task*, then a rule."""
        line = ('| ' + task +
                ' loss {:10.5f} | acc {:5.2f}% | f1 {:5.2f}%'.format(
                    loss, acc * 100, f1 * 100))
        print(line)
        print('-' * 80)

    def _save_model(self, model, filename):
        """Serialize *model*'s parameters to ./weights/<filename>."""
        destination = './weights/' + filename
        torch.save(model.state_dict(), destination)

    def _load_model(self, model, filename):
        """Restore *model*'s parameters from ./weights/<filename>."""
        source = './weights/' + filename
        model.load_state_dict(torch.load(source))

    def _train_one(self):
        """Run one training epoch over the sense classification task.

        When ``conf.four_or_eleven == 2`` the sense labels are binarized
        in place against ``conf.binclass``; when ``conf.is_mttrain`` an
        auxiliary connective-prediction loss (weighted by ``conf.lambda1``)
        is added. Returns (mean loss, accuracy) over the training set.

        NOTE(review): pre-0.4 PyTorch API (Variable, ``.data`` arithmetic,
        deprecated ``clip_grad_norm``); ``total_loss[0]`` indexes the
        1-element ``.data`` tensor of a scalar loss.
        """
        self.encoder.train()
        self.classifier.train()
        if self.conf.is_mttrain:
            self.conn_classifier.train()
        total_loss = 0
        correct_n = 0
        train_size = self.data.train_size
        for a1, a2, sense, conn in self.data.train_loader:
            if self.conf.four_or_eleven == 2:
                # Binary task: label 1 iff the sense equals conf.binclass.
                mask1 = (sense == self.conf.binclass)
                mask2 = (sense != self.conf.binclass)
                sense[mask1] = 1
                sense[mask2] = 0
            if self.cuda:
                a1, a2, sense, conn = a1.cuda(), a2.cuda(), sense.cuda(
                ), conn.cuda()
            a1, a2, sense, conn = Variable(a1), Variable(a2), Variable(
                sense), Variable(conn)
            repr = self.encoder(a1, a2)
            output = self.classifier(repr)
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data
            loss = self.criterion(output, sense)

            if self.conf.is_mttrain:
                # Multi-task: also predict the connective from the same
                # representation.
                conn_output = self.conn_classifier(repr)
                loss2 = self.criterion(conn_output, conn)
                loss = loss + loss2 * self.conf.lambda1

            self.e_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            if self.conf.is_mttrain:
                self.con_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.encoder.parameters(),
                                          self.conf.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          self.conf.grad_clip)
            if self.conf.is_mttrain:
                torch.nn.utils.clip_grad_norm(
                    self.conn_classifier.parameters(), self.conf.grad_clip)
            self.e_optimizer.step()
            self.c_optimizer.step()
            if self.conf.is_mttrain:
                self.con_optimizer.step()

            # Weight the (mean) batch loss by batch size for a dataset mean.
            total_loss += loss.data * sense.size(0)
        return total_loss[0] / train_size, correct_n[0] / train_size

    def _train(self, pre):
        """Run the epoch loop: train once, then evaluate and log dev/test.

        *pre* is accepted for interface compatibility but unused here.
        All metrics go both to stdout and to ``self.logwriter``.
        """
        for epoch in range(self.conf.epochs):
            began = time.time()
            tr_loss, tr_acc = self._train_one()
            self._print_train(epoch, time.time() - began, tr_loss, tr_acc)
            self.logwriter.add_scalar('loss/train_loss', tr_loss, epoch)
            self.logwriter.add_scalar('acc/train_acc', tr_acc * 100, epoch)

            # Same evaluation/logging pattern for both held-out splits; the
            # formatted tag names match the literal ones logged previously.
            for split in ('dev', 'test'):
                ev_loss, ev_acc, ev_f1 = self._eval(split)
                self._print_eval(split, ev_loss, ev_acc, ev_f1)
                self.logwriter.add_scalar('loss/%s_loss' % split, ev_loss,
                                          epoch)
                self.logwriter.add_scalar('acc/%s_acc' % split, ev_acc * 100,
                                          epoch)
                self.logwriter.add_scalar('f1/%s_f1' % split, ev_f1 * 100,
                                          epoch)

    def train(self, pre):
        """Public entry point: open the tensorboard writer and run all epochs.

        `pre` is the checkpoint-name prefix, forwarded to `_train`.
        """
        print('start training')
        # Writer is created lazily here (not in __init__) so eval-only runs
        # never touch the log directory.
        self.logwriter = SummaryWriter(self.conf.logdir)
        self._train(pre)
        print('training done')

    def _eval(self, task):
        """Evaluate the current encoder + classifier on one held-out split.

        Args:
            task: 'dev' or 'test'; selects the data loader and split size.

        Returns:
            Tuple of (average loss, accuracy, F1) over the split. F1 is
            binary-averaged in the 2-class setup, macro-averaged otherwise.

        Raises:
            Exception: if `task` is neither 'dev' nor 'test'.
        """
        self.encoder.eval()
        self.classifier.eval()
        total_loss = 0
        correct_n = 0
        if task == 'dev':
            data = self.data.dev_loader
            n = self.data.dev_size
        elif task == 'test':
            data = self.data.test_loader
            n = self.data.test_size
        else:
            raise Exception('wrong eval task')
        output_list = []
        gold_list = []
        for a1, a2, sense1, sense2 in data:
            if self.conf.four_or_eleven == 2:
                # Binary setup: collapse labels in-place to 1 (== binclass)
                # vs 0 (anything else). mask0 re-applies the -1 marker so a
                # missing second sense stays -1 after the collapse.
                mask1 = (sense1 == self.conf.binclass)
                mask2 = (sense1 != self.conf.binclass)
                sense1[mask1] = 1
                sense1[mask2] = 0
                mask0 = (sense2 == -1)
                mask1 = (sense2 == self.conf.binclass)
                mask2 = (sense2 != self.conf.binclass)
                sense2[mask1] = 1
                sense2[mask2] = 0
                sense2[mask0] = -1
            if self.cuda:
                a1, a2, sense1, sense2 = a1.cuda(), a2.cuda(), sense1.cuda(
                ), sense2.cuda()
            # volatile=True is the legacy (pre-0.4 PyTorch) way of disabling
            # autograd during inference.
            a1 = Variable(a1, volatile=True)
            a2 = Variable(a2, volatile=True)
            sense1 = Variable(sense1, volatile=True)
            sense2 = Variable(sense2, volatile=True)

            output = self.classifier(self.encoder(a1, a2))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense1.size()
            # PDTB-style scoring: a prediction matching EITHER annotated
            # sense counts as correct, so where the prediction equals
            # sense2 we score against sense2 instead of sense1.
            # NOTE(review): gold_sense aliases sense1 (no copy), so sense1
            # is mutated by the masked assignment below — confirm intended.
            gold_sense = sense1
            mask = (output_sense == sense2)
            gold_sense[mask] = sense2[mask]
            tmp = (output_sense == gold_sense).long()
            correct_n += torch.sum(tmp).data

            output_list.append(output_sense)
            gold_list.append(gold_sense)

            # Loss is computed against the adjusted gold labels and
            # re-weighted by batch size so the final division by n gives a
            # per-example average.
            loss = self.criterion(output, gold_sense)
            total_loss += loss.data * gold_sense.size(0)

        output_s = torch.cat(output_list)
        gold_s = torch.cat(gold_list)
        if self.conf.four_or_eleven == 2:
            f1 = f1_score(gold_s.cpu().data.numpy(),
                          output_s.cpu().data.numpy(),
                          average='binary')
        else:
            f1 = f1_score(gold_s.cpu().data.numpy(),
                          output_s.cpu().data.numpy(),
                          average='macro')
        return total_loss[0] / n, correct_n[0] / n, f1

    def eval(self, pre):
        """Load saved encoder/classifier weights and report test metrics.

        `pre` is the checkpoint-name prefix used when the parameters were
        saved ('<pre>_eparams.pkl' / '<pre>_cparams.pkl').
        """
        print('evaluating...')
        self._load_model(self.encoder, pre + '_eparams.pkl')
        self._load_model(self.classifier, pre + '_cparams.pkl')
        test_loss, test_acc, f1 = self._eval('test')
        self._print_eval('test', test_loss, test_acc, f1)
예제 #6
0
                                batch_size=batchSize,
                                shuffle=True,
                                num_workers=4)

    numClass = 4
    numFeat = 32
    dropout = 0.25
    modelConv = DownSampler(numFeat, False, dropout)
    modelClass = Classifier(numFeat * 2, numClass, 4)
    modelHess = BNNL()
    if hessMC:
        modelHess = BNNMC()
    weights = torch.ones(numClass)
    if torch.cuda.is_available():
        modelConv = modelConv.cuda()
        modelClass = modelClass.cuda()
        modelHess = modelHess.cuda()
        weights = weights.cuda()

    criterion = torch.nn.CrossEntropyLoss(weights)

    mapLoc = None if torch.cuda.is_available() else {'cuda:0': 'cpu'}

    epochs = 80
    lr = 1e-2
    weight_decay = 5e-4
    momentum = 0.9

    def cb():
        print("Best Model reloaded")
        if hessMC:
예제 #7
0
class Solver(object):
    """StarGAN-style trainer over a single dataset (MS-Celeb loader).

    Wires together a generator G, a WGAN critic D, and a separate identity
    classifier C, then trains them adversarially with a gradient penalty on
    D and an L1-reconstruction + total-variation term on G.
    """

    def __init__(self, Msceleb_loader, config):
        """Copy hyper-parameters from `config`, build the networks, and
        optionally restore a pretrained checkpoint."""
        # Data loader
        self.Msceleb_loader = Msceleb_loader

        # Model hyper-parameters
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Hyper-parameteres
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        # self.dataset = config.dataset
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        # Test settings
        self.test_model = config.test_model

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # self.lambda_face = 0.0
        # if self.lambda_face > 0.0:
        #     self.Face_recognition_network = face_recognition_networks.LightCNN_29Layers(num_classes=79077)
        #     self.Face_recognition_network = torch.nn.DataParallel(self.Face_recognition_network).cuda()
        #     checkpoint = torch.load(r'/data5/shentao/LightCNN/CNN_29.pkl')
        #     self.Face_recognition_network.load_state_dict(checkpoint)
        #     for param in self.Face_recognition_network.parameters():
        #         param.requires_grad = False
        #     self.Face_recognition_network.eval()

        # Build tensorboard if use
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate G/D/C, their Adam optimizers, and move them to GPU."""

        self.G = Generator(self.g_conv_dim, self.g_repeat_num)
        self.D = Discriminator(self.d_conv_dim, self.d_repeat_num)
        self.C = Classifier(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        # Optimizers — C shares the discriminator learning rate.
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Print networks
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.C, 'C')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()
            self.C.cuda()

    def print_network(self, model, name):
        """Print a model's structure and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """Restore G/D/C weights saved under the `pretrained_model` step tag."""
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        self.C.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_C.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def build_tensorboard(self):
        """Create the (project-local) tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, g_lr, d_lr):
        """Set new learning rates; C's optimizer follows d_lr."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Zero the gradients of all three optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap a tensor in a (legacy) Variable, moving it to GPU if available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1] for saving."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        """Binarize at 0.5 (returns a copy; input is untouched)."""
        x = x.clone()
        x[x >= 0.5] = 1
        x[x < 0.5] = 0
        return x

    def compute_accuracy(self, x, y):
        """Top-1 accuracy (in percent) of logits `x` against labels `y`.

        NOTE(review): in legacy PyTorch the mean is a 1-element tensor,
        which train() then iterates via .data.cpu().numpy() — confirm this
        matches the installed torch version.
        """
        _, predicted = torch.max(x, dim=1)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vector"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)

        out[torch.from_numpy(np.arange(batch_size).astype(np.int64)), labels.long()] = 1
        return out


    def train(self):
        """Train StarGAN within a single dataset."""
        self.criterionL1 = torch.nn.L1Loss()
        # self.criterionL2 = torch.nn.MSELoss()
        self.criterionTV = TVLoss()

        self.data_loader = self.Msceleb_loader
        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Grab the first 4 batches as fixed debug inputs.
        fixed_x = []
        real_c = []
        for i, (aug_images, aug_labels, _, _) in enumerate(self.data_loader):
            fixed_x.append(aug_images)
            real_c.append(aug_labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (aug_x, aug_label, origin_x, origin_label) in enumerate(self.data_loader):

                # Generat fake labels randomly (target domain labels)
                # aug_c = self.one_hot(aug_label, self.c_dim)
                # origin_c = self.one_hot(origin_label, self.c_dim)

                aug_c_V = self.to_var(aug_label)
                origin_c_V = self.to_var(origin_label)

                aug_x = self.to_var(aug_x)
                origin_x = self.to_var(origin_x)

                # # ================== Train D ================== #
                # Compute loss with real images
                # WGAN critic loss: maximize D(real), so minimize -mean.
                out_src = self.D(origin_x)
                out_cls = self.C(origin_x)
                d_loss_real = - torch.mean(out_src)

                c_loss_cls = F.cross_entropy(out_cls, origin_c_V)
                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, origin_c_V)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    print('Classification Acc (75268 ids): ')
                    print(log)

                # Compute loss with fake images
                # Re-wrapping .data detaches fake_x so G gets no gradient here.
                fake_x = self.G(aug_x)
                fake_x = Variable(fake_x.data)
                out_src = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake
                c_loss = self.lambda_cls * c_loss_cls


                self.reset_grad()
                d_loss.backward()
                c_loss.backward()
                self.d_optimizer.step()
                self.c_optimizer.step()

                # Compute gradient penalty
                # (WGAN-GP: penalize ||grad D|| at random real/fake interpolations)
                alpha = torch.rand(origin_x.size(0), 1, 1, 1).cuda().expand_as(origin_x)
                interpolated = Variable(alpha * origin_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                # Penalty is applied in a second, separate D update.
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]
                loss['C/loss_cls'] = c_loss_cls.data[0]

                # ================== Train G ================== #
                # G is updated once every d_train_repeat D updates.
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(aug_x)

                    # Compute losses
                    out_src = self.D(fake_x)
                    out_cls = self.C(fake_x)
                    g_loss_fake = - torch.mean(out_src)

                    g_loss_cls = F.cross_entropy(out_cls, aug_c_V)

                    # Backward + Optimize
                    recon_loss = self.criterionL1(fake_x, aug_x)
                    TV_loss = self.criterionTV(fake_x) * 0.001

                    g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + 5* recon_loss + TV_loss

                    # if self.lambda_face > 0.0:
                    #     self.criterionFace = nn.L1Loss()
                    #
                    #     real_input_x = (torch.sum(real_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #     fake_input_x = (torch.sum(fake_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #     rec_input_x = (torch.sum(rec_x, 1, keepdim=True) / 3.0 + 1) / 2.0
                    #
                    #     _, real_x_feature_fc, real_x_feature_conv = self.Face_recognition_network.forward(
                    #         real_input_x)
                    #     _, fake_x_feature_fc, fake_x_feature_conv = self.Face_recognition_network.forward(
                    #         fake_input_x)
                    #     _, rec_x1_feature_fc, rec_x1_feature_conv = self.Face_recognition_network.forward(rec_input_x)
                    #     # x1_loss = (self.criterionFace(fake_x1_feature_fc, Variable(real_x1_feature_fc.data,requires_grad=False)) +
                    #     #            self.criterionFace(fake_x1_feature_conv,Variable(real_x1_feature_conv.data,requires_grad=False)))\
                    #     #            * self.lambda_face
                    #     x_loss = (self.criterionFace(fake_x_feature_fc,Variable(real_x_feature_fc.data, requires_grad=False))) \
                    #               * self.lambda_face
                    #
                    #     rec_x_loss = (self.criterionFace(rec_x1_feature_fc, Variable(real_x_feature_fc.data, requires_grad=False)))
                    #
                    #     self.id_loss = x_loss + rec_x_loss
                    #     loss['G/id_loss'] = self.id_loss.data[0]
                    #     g_loss += self.id_loss

                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)


                # Translate fixed images for debugging
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]

                    fake_image_list.append(self.G(fixed_x))

                    # Side-by-side: original | translated.
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))
                    torch.save(self.C.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_C.pth'.format(e+1, i+1)))


            # Decay learning rate
            # (linear decay over the last num_epochs_decay epochs)
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))


            torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_G.pth'.format(e + 1)))
            torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_D.pth'.format(e + 1)))
            torch.save(self.C.state_dict(),
                           os.path.join(self.model_save_path, '{}_final_C.pth'.format(e + 1)))
예제 #8
0
def _mnist_transform(args):
    """Shared MNIST preprocessing: resize, to-tensor, normalize to [-1, 1]."""
    return tv.transforms.Compose([
        tv.transforms.Resize(args.img_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize([0.5], [0.5])
    ])


def _mnist_loader(args, train):
    """Build a MNIST DataLoader; only the training split is shuffled."""
    dataset = tv.datasets.MNIST(args.data_root,
                                train=train,
                                transform=_mnist_transform(args),
                                download=True)
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      shuffle=train,
                      num_workers=4)


def main(args):
    """Run adversarial training ('train') or evaluation ('valid') on MNIST.

    Sets up log/model folders, builds the classifier and the FGSM-style
    attack, then dispatches on `args.todo`.

    Raises:
        NotImplementedError: for any `args.todo` other than 'train'/'valid'.
    """
    save_folder = '%s_%s' % (args.dataset, args.affix)

    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(log_folder)
    makedirs(model_folder)

    # Record the derived folders back on args so downstream code sees them.
    setattr(args, 'log_folder', log_folder)
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, args.todo, 'info')

    print_args(args, logger)

    model = Classifier(10)

    attack = FastGradientSignUntargeted(model,
                                        args.epsilon,
                                        args.alpha,
                                        min_val=-1,
                                        max_val=1,
                                        max_iters=args.k,
                                        _type=args.perturbation_type)

    trainer = Trainer(args, logger, attack)

    if args.todo == 'train':
        if torch.cuda.is_available():
            model.cuda()
        tr_loader = _mnist_loader(args, train=True)

        # evaluation during training
        te_loader = _mnist_loader(args, train=False)

        trainer.train(model, tr_loader, te_loader, args.adv_train)
    elif args.todo == 'valid':
        load_model(model, args.load_checkpoint, args)

        if torch.cuda.is_available():
            model.cuda()
        te_loader = _mnist_loader(args, train=False)

        test_acc, adv_acc = trainer.test(model, te_loader, adv_test=False)
        print('Test accuracy is %.3f' % test_acc)
    else:
        raise NotImplementedError
예제 #9
0
class Training(object):
    def __init__(self, config, logger=None):
        """Build embedder/encoder/classifier (+ optional AE), optimizer,
        LR scheduler, EMA trackers, and load or initialize model state.

        Args:
            config: project config object with all hyper-parameters.
            logger: optional logger; a default DEBUG logger is created
                when omitted.
        """
        if logger is None:
            logger = logging.getLogger('logger')
            logger.setLevel(logging.DEBUG)
            logging.basicConfig(format='%(message)s', level=logging.DEBUG)

        self.logger = logger
        self.config = config
        self.classes = list(config.id2label.keys())
        self.num_classes = config.num_classes

        self.embedder = Embedder(self.config)
        self.encoder = LSTMEncoder(self.config)
        self.clf = Classifier(self.config)
        self.clf_loss = SequenceCriteria(class_weight=None)
        # Auto-encoder branch only exists when its loss weight is active.
        if self.config.lambda_ae > 0: self.ae = AEModel(self.config)

        self.writer = SummaryWriter(log_dir="TFBoardSummary")
        self.global_steps = 0
        self.enc_clf_opt = Adam(self._get_trainabe_modules(),
                                lr=self.config.lr,
                                betas=(config.beta1, config.beta2),
                                weight_decay=config.weight_decay,
                                eps=config.eps)

        if config.scheduler == "ReduceLROnPlateau":
            # mode='max': the monitored quantity is an accuracy-like score.
            self.scheduler = lr_scheduler.ReduceLROnPlateau(
                self.enc_clf_opt,
                mode='max',
                factor=config.lr_decay,
                patience=config.patience,
                verbose=True)
        elif config.scheduler == "ExponentialLR":
            self.scheduler = lr_scheduler.ExponentialLR(self.enc_clf_opt,
                                                        gamma=config.gamma)

        self._init_or_load_model()
        # NOTE(review): despite the name, multi_gpu just moves modules to
        # CUDA here (no DataParallel) — confirm against the config docs.
        if config.multi_gpu:
            self.embedder.cuda()
            self.encoder.cuda()
            self.clf.cuda()
            self.clf_loss.cuda()
            if self.config.lambda_ae > 0: self.ae.cuda()

        # EMA shadows are registered AFTER any .cuda() move so the shadow
        # tensors live on the same device as the live parameters.
        self.ema_embedder = ExponentialMovingAverage(decay=0.999)
        self.ema_embedder.register(self.embedder.state_dict())
        self.ema_encoder = ExponentialMovingAverage(decay=0.999)
        self.ema_encoder.register(self.encoder.state_dict())
        self.ema_clf = ExponentialMovingAverage(decay=0.999)
        self.ema_clf.register(self.clf.state_dict())

        self.time_s = time()

    def _get_trainabe_modules(self):
        """Return one flat list of all trainable parameters, in the order
        embedder, encoder, classifier, then (optionally) the auto-encoder.

        (Name typo kept for compatibility with existing callers.)
        """
        modules = [self.embedder, self.encoder, self.clf]
        if self.config.lambda_ae > 0:
            modules.append(self.ae)
        params = []
        for module in modules:
            params.extend(module.parameters())
        return params

    def _get_l2_norm_loss(self):
        """Sum of the L2 norms of every trainable parameter.

        Note: this is a sum of norms, not the usual squared-norm penalty —
        the `** 2` / `/ 2.` variants were deliberately left out.
        """
        norm_sum = 0.
        for param in self._get_trainabe_modules():
            norm_sum = norm_sum + param.data.norm(p=2)
        return norm_sum

    def _init_or_load_model(self):
        """Prepare the output directory and reset training progress.

        Starts every run fresh: epoch 0, best accuracy at -inf.
        """
        # if not self._load_model():
        ensure_directory(self.config.output_path)
        self.best_accuracy = float("-inf")
        self.epoch = 0

    def _compute_vocab_freq(self, train_, dev_):
        """Count token-id frequencies over train + dev.

        Each dataset yields (label, ids) pairs; returns a float tensor of
        length config.n_vocab where entry i is the frequency of token i.
        """
        freq_counter = collections.Counter()
        for split in (train_, dev_):
            for _, token_ids in split:
                freq_counter.update(token_ids)
        freqs = np.zeros(self.config.n_vocab)
        for token_idx, count in freq_counter.items():
            freqs[token_idx] = count
        return torch.from_numpy(freqs).type(batch_utils.FLOAT_TYPE)

    def _save_model(self):
        """Write a checkpoint with the EMA (shadow) weights of all three
        modules plus the current epoch and best accuracy."""
        checkpoint = {
            'epoch': self.epoch,
            # The exponential-moving-average shadows are saved, not the
            # live state_dicts of the modules themselves.
            'state_dict_encoder': self.ema_encoder.shadow_variable_dict,
            'state_dict_embedder': self.ema_embedder.shadow_variable_dict,
            'state_dict_clf': self.ema_clf.shadow_variable_dict,
            'best_accuracy': self.best_accuracy
        }
        out_path = os.path.join(self.config.output_path,
                                self.config.model_file)
        torch.save(checkpoint, out_path)

    def _load_model(self):
        """Restore model/trainer state from the configured checkpoint.

        Only acts when `config.load_checkpoint` is set and the file exists.

        Returns:
            True when a checkpoint was found and loaded, False otherwise
            (previously an implicit None; both are falsy for callers).
        """
        checkpoint_path = os.path.join(self.config.output_path,
                                       self.config.model_file)
        if self.config.load_checkpoint and os.path.isfile(checkpoint_path):
            # Code taken from here: https://github.com/pytorch/examples/blob/master/imagenet/main.py
            dict_ = torch.load(checkpoint_path)
            self.epoch = dict_['epoch']
            self.best_accuracy = dict_['best_accuracy']
            self.embedder.load_state_dict(dict_['state_dict_embedder'])
            self.encoder.load_state_dict(dict_['state_dict_encoder'])
            self.clf.load_state_dict(dict_['state_dict_clf'])
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_path, self.epoch))
            return True
        return False

    def __call__(self, train, dev, test, unlabel, addn, addn_un, addn_test):
        """Run the full pipeline: train on labeled/unlabeled data, then
        evaluate once on the test split.

        `addn*` arguments carry the additional (auxiliary) feature tensors
        for the corresponding splits.
        """
        self.logger.info('Start training')
        self._train(train, dev, unlabel, addn, addn_un, addn_test)
        self._evaluate(test, addn_test)

    def _create_iter(self, data_, wbatchsize, random_shuffler=None):
        """Build a word-batched iterator over `data_` via torchtext's pool.

        Args:
            data_: sequence of (label, token_ids) examples.
            wbatchsize: target number of words per batch.
            random_shuffler: optional shuffling callable forwarded to the
                pool. BUGFIX: this parameter was previously accepted but
                ignored (`random_shuffler=None` was hard-coded); it is now
                passed through. The default None preserves old behavior.
        """
        iter_ = data.iterator.pool(data_,
                                   wbatchsize,
                                   key=lambda x: len(x[1]),
                                   batch_size_fn=batch_size_fn,
                                   random_shuffler=random_shuffler)
        return iter_

    def _run_epoch(self, train_data, dev_data, unlabel_data, addn_data,
                   addn_data_unlab, addn_dev):
        addn_dev.cuda()
        report_stats = utils.Statistics()
        cm = ConfusionMatrix(self.classes)
        _, seq_data = list(zip(*train_data))
        total_seq_words = len(list(itertools.chain.from_iterable(seq_data)))
        iter_per_epoch = (1.5 * total_seq_words) // self.config.wbatchsize

        self.encoder.train()
        self.clf.train()
        self.embedder.train()
        #         print(addn_data)
        #         print(addn_data.shape)
        #         print(train_data[:5])
        train_iter = self._create_iter(train_data, self.config.wbatchsize)
        #         addn_iter = self._create_iter(addn_data, self.config.wbatchsize)
        #         train_iter = self._create_iter(zip(train_data, addn_data), self.config.wbatchsize)
        unlabel_iter = self._create_iter(unlabel_data,
                                         self.config.wbatchsize_unlabel)

        sofar = 0
        sofar_1 = 0
        for batch_index, train_batch_raw in enumerate(train_iter):
            seq_iter = list(zip(*train_batch_raw))[1]
            seq_words = len(list(itertools.chain.from_iterable(seq_iter)))
            report_stats.n_words += seq_words
            self.global_steps += 1

            # self.enc_clf_opt.zero_grad()
            if self.config.add_noise:
                train_batch_raw = add_noise(train_batch_raw,
                                            self.config.noise_dropout,
                                            self.config.random_permutation)
            train_batch = batch_utils.seq_pad_concat(train_batch_raw, -1)

            #             print(train_batch.shape)

            train_embedded = self.embedder(train_batch)
            memory_bank_train, enc_final_train = self.encoder(
                train_embedded, train_batch)

            if self.config.lambda_vat > 0 or self.config.lambda_ae > 0 or self.config.lambda_entropy:
                try:
                    unlabel_batch_raw = next(unlabel_iter)
                except StopIteration:
                    unlabel_iter = self._create_iter(
                        unlabel_data, self.config.wbatchsize_unlabel)
                    unlabel_batch_raw = next(unlabel_iter)

                if self.config.add_noise:
                    unlabel_batch_raw = add_noise(
                        unlabel_batch_raw, self.config.noise_dropout,
                        self.config.random_permutation)
                unlabel_batch = batch_utils.seq_pad_concat(
                    unlabel_batch_raw, -1)
                unlabel_embedded = self.embedder(unlabel_batch)
                memory_bank_unlabel, enc_final_unlabel = self.encoder(
                    unlabel_embedded, unlabel_batch)

#             print(memory_bank_unlabel.shape[0])
            addn_batch_unlab = retAddnBatch(addn_data_unlab,
                                            memory_bank_unlabel.shape[0],
                                            sofar_1).cuda()
            sofar_1 += addn_batch_unlab.shape[0]
            #             print(addn_batch_unlab.shape)

            #             print(memory_bank_train.shape[0])
            addn_batch = retAddnBatch(addn_data, memory_bank_train.shape[0],
                                      sofar).cuda()
            sofar += addn_batch.shape[0]
            #             print(addn_batch.shape)
            pred = self.clf(memory_bank_train, addn_batch, enc_final_train)
            #             print(pred)
            accuracy = self.get_accuracy(cm, pred.data,
                                         train_batch.labels.data)
            lclf = self.clf_loss(pred, train_batch.labels)

            lat = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            lvat = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            if self.config.lambda_at > 0:
                lat = at_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    perturb_norm_length=self.config.perturb_norm_length)

            if self.config.lambda_vat > 0:
                lvat_train = vat_loss(
                    self.embedder,
                    self.encoder,
                    self.clf,
                    train_batch,
                    addn_batch,
                    p_logit=pred,
                    perturb_norm_length=self.config.perturb_norm_length)
                if self.config.inc_unlabeled_loss:
                    lvat_unlabel = vat_loss(
                        self.embedder,
                        self.encoder,
                        self.clf,
                        unlabel_batch,
                        addn_batch_unlab,
                        p_logit=self.clf(memory_bank_unlabel, addn_batch_unlab,
                                         enc_final_unlabel),
                        perturb_norm_length=self.config.perturb_norm_length)
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lvat = 0.5 * (lvat_train + lvat_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lvat = lvat_unlabel
                else:
                    lvat = lvat_train

            lentropy = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            if self.config.lambda_entropy > 0:
                lentropy_train = entropy_loss(pred)
                if self.config.inc_unlabeled_loss:
                    lentropy_unlabel = entropy_loss(
                        self.clf(memory_bank_unlabel, addn_batch_unlab,
                                 enc_final_unlabel))
                    if self.config.unlabeled_loss_type == "AvgTrainUnlabel":
                        lentropy = 0.5 * (lentropy_train + lentropy_unlabel)
                    elif self.config.unlabeled_loss_type == "Unlabel":
                        lentropy = lentropy_unlabel
                else:
                    lentropy = lentropy_train

            lae = Variable(
                torch.FloatTensor([-1.]).type(batch_utils.FLOAT_TYPE))
            if self.config.lambda_ae > 0:
                lae = self.ae(memory_bank_unlabel, enc_final_unlabel,
                              unlabel_batch.sent_len, unlabel_batch_raw)

            ltotal = (self.config.lambda_clf * lclf) + \
                     (self.config.lambda_ae * lae) + \
                     (self.config.lambda_at * lat) + \
                     (self.config.lambda_vat * lvat) + \
                     (self.config.lambda_entropy * lentropy)

            report_stats.clf_loss += lclf.data.cpu().numpy()
            report_stats.at_loss += lat.data.cpu().numpy()
            report_stats.vat_loss += lvat.data.cpu().numpy()
            report_stats.ae_loss += lae.data.cpu().numpy()
            report_stats.entropy_loss += lentropy.data.cpu().numpy()
            report_stats.n_sent += len(pred)
            report_stats.n_correct += accuracy
            self.enc_clf_opt.zero_grad()
            ltotal.backward()

            params_list = self._get_trainabe_modules()
            # Excluding embedder form norm constraint when AT or VAT
            if not self.config.normalize_embedding:
                params_list += list(self.embedder.parameters())

            norm = torch.nn.utils.clip_grad_norm(params_list,
                                                 self.config.max_norm)
            report_stats.grad_norm += norm
            self.enc_clf_opt.step()
            if self.config.scheduler == "ExponentialLR":
                self.scheduler.step()
            self.ema_embedder.apply(self.embedder.named_parameters())
            self.ema_encoder.apply(self.encoder.named_parameters())
            self.ema_clf.apply(self.clf.named_parameters())

            report_func(self.epoch, batch_index, iter_per_epoch, self.time_s,
                        report_stats, self.config.report_every, self.logger)

            if self.global_steps % self.config.eval_steps == 0:
                cm_, accuracy, prc_dev = self._run_evaluate(dev_data, addn_dev)
                self.logger.info(
                    "- dev accuracy {} | best dev accuracy {} ".format(
                        accuracy, self.best_accuracy))
                self.writer.add_scalar("Dev_Accuracy", accuracy,
                                       self.global_steps)
                pred_, lab_ = zip(*prc_dev)
                pred_ = torch.cat(pred_)
                lab_ = torch.cat(lab_)
                self.writer.add_pr_curve("Dev PR-Curve", lab_, pred_,
                                         self.global_steps)
                pprint.pprint(cm_)
                pprint.pprint(cm_.get_all_metrics())
                if accuracy > self.best_accuracy:
                    self.logger.info("- new best score!")
                    self.best_accuracy = accuracy
                    self._save_model()
                if self.config.scheduler == "ReduceLROnPlateau":
                    self.scheduler.step(accuracy)
                self.encoder.train()
                self.embedder.train()
                self.clf.train()

                if self.config.weight_decay > 0:
                    print(">> Square Norm: %1.4f " % self._get_l2_norm_loss())

        cm, train_accuracy, _ = self._run_evaluate(train_data, addn_data)
        self.logger.info("- Train accuracy  {}".format(train_accuracy))
        pprint.pprint(cm.get_all_metrics())

        cm, dev_accuracy, _ = self._run_evaluate(dev_data, addn_dev)
        self.logger.info("- Dev accuracy  {} | best dev accuracy {}".format(
            dev_accuracy, self.best_accuracy))
        pprint.pprint(cm.get_all_metrics())
        self.writer.add_scalars("Overall_Accuracy", {
            "Train_Accuracy": train_accuracy,
            "Dev_Accuracy": dev_accuracy
        }, self.global_steps)
        return dev_accuracy

    @staticmethod
    def get_accuracy(cm, output, target):
        batch_size = output.size(0)
        predictions = output.max(-1)[1].type_as(target)
        correct = predictions.eq(target)
        correct = correct.float()
        if not hasattr(correct, 'sum'):
            correct = correct.cpu()
        correct = correct.sum()
        cm.add_batch(target.cpu().numpy(), predictions.cpu().numpy())
        return correct

    def _predict_batch(self, cm, batch, addn_batch):
        """Score one batch with the model frozen in eval mode.

        Returns the classifier logits together with the number of correct
        predictions; the confusion matrix ``cm`` is updated as a side effect.
        """
        for module in (self.embedder, self.encoder, self.clf):
            module.eval()
        memory_bank, enc_final = self.encoder(self.embedder(batch), batch)
        logits = self.clf(memory_bank, addn_batch, enc_final)
        n_correct = self.get_accuracy(cm, logits.data, batch.labels.data)
        return logits, n_correct

    def chunks(self, l, n=15):
        """Yield successive n-sized chunks from l."""
        start = 0
        total = len(l)
        while start < total:
            yield l[start:start + n]
            start += n

    def _run_evaluate(self, test_data, addn_test):
        """Evaluate the current model over ``test_data`` in fixed-size chunks.

        Args:
            test_data: list of raw examples.
            addn_test: additional-feature tensor aligned index-for-index
                with ``test_data``.

        Returns:
            ``(cm, accuracy, pr_curve_data)`` — the filled ConfusionMatrix,
            the accuracy in percent, and a list of
            (positive-class probability, gold label) tensor pairs.
        """
        # Single source of truth for the chunk size: the same value drives
        # both the example chunking and the addn_test slicing below
        # (previously a magic 15 appeared in two places).
        chunk_size = 15
        pr_curve_data = []
        cm = ConfusionMatrix(self.classes)
        accuracy_list = []
        test_iter = self.chunks(test_data, chunk_size)

        for batch_index, test_batch in enumerate(test_iter):
            # Additional features for exactly the examples in this chunk.
            addn_batch = addn_test[batch_index * chunk_size:
                                   (batch_index + 1) * chunk_size]
            test_batch = batch_utils.seq_pad_concat(test_batch, -1)
            try:
                pred, acc = self._predict_batch(cm, test_batch, addn_batch)
            except Exception:
                # Was a bare ``except:``; narrowed so KeyboardInterrupt and
                # SystemExit still propagate. Failed batches are skipped
                # (best-effort evaluation, as in the original).
                continue
            accuracy_list.append(acc)
            pr_curve_data.append(
                (F.softmax(pred, -1)[:, 1].data, test_batch.labels.data))
        # Skipped batches count as zero correct: we divide by len(test_data).
        accuracy = 100 * (sum(accuracy_list) / len(test_data))
        return cm, accuracy, pr_curve_data

    def _train(self, train_data, dev_data, unlabel_data, addn_data,
               addn_data_unlab, addn_dev):
        """Run the full training loop with checkpointing and early stopping.

        Moves the additional-feature tensors to GPU once, then trains for
        ``config.nepochs`` epochs, saving the model whenever dev accuracy
        improves and stopping early after ``config.nepoch_no_imprv`` epochs
        without improvement.
        """
        addn_data = addn_data.cuda()
        addn_data_unlab = addn_data_unlab.cuda()
        addn_dev = addn_dev.cuda()

        epochs_without_improvement = 0  # early-stopping patience counter

        first_epoch = self.epoch + 1
        last_epoch = self.epoch + self.config.nepochs + 1
        for self.epoch in range(first_epoch, last_epoch):
            self.logger.info("Epoch {:} out of {:}".format(
                self.epoch, self.config.nepochs))
            accuracy = self._run_epoch(train_data, dev_data, unlabel_data,
                                       addn_data, addn_data_unlab, addn_dev)

            if accuracy > self.best_accuracy:
                # New best dev score: checkpoint and reset patience.
                epochs_without_improvement = 0
                self.best_accuracy = accuracy
                self.logger.info("- new best score!")
                self._save_model()
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= self.config.nepoch_no_imprv:
                    self.logger.info(
                        "- early stopping {} epochs without improvement".
                        format(epochs_without_improvement))
                    break
            if self.config.scheduler == "ReduceLROnPlateau":
                self.scheduler.step(accuracy)

    def _evaluate(self, test_data, addn_test):
        """Load the best checkpoint, score the test set, and dump a
        prediction/label TSV for offline PR-curve analysis."""
        addn_test = addn_test.cuda()
        self.logger.info("Evaluating model over test set")
        self._load_model()
        _, accuracy, prc_test = self._run_evaluate(test_data, addn_test)
        scores, labels = zip(*prc_test)
        scores = torch.cat(scores).cpu().tolist()
        labels = torch.cat(labels).cpu().tolist()
        out_path = os.path.join(self.config.output_path,
                                "{}_pred_gt.tsv".format(self.config.exp_name))
        # One "<score>\t<label>" row per test example.
        with open(out_path, 'w') as fp:
            fp.writelines(str(score) + '\t' + str(label) + '\n'
                          for score, label in zip(scores, labels))
        self.logger.info("- test accuracy {}".format(accuracy))
예제 #10
0
def main():
    """Train an image classifier (VGG/ResNet/DenseNet backbone + Classifier head).

    Resumes from epoch ``opt.r`` when non-negative, logs loss to TensorBoard
    under ``./runs2``, and checkpoints the feature extractor and classifier
    whenever validation accuracy improves.
    """
    resume_ep = opt.r
    train_dir = opt.train_dir
    check_dir = opt.check_dir
    val_dir = opt.val_dir

    bsize = opt.b
    iter_num = opt.e

    # Per-class weights for the (imbalanced) two-class cross-entropy loss.
    label_weight = [4.858, 17.57]
    # ImageNet normalization stats (used by the commented-out image logging).
    std = [.229, .224, .225]
    mean = [.485, .456, .406]

    # Ensure the log dir exists BEFORE wiping it and creating the writer.
    # (The original checked existence only after SummaryWriter had already
    # created the directory, making that check dead code.)
    if not os.path.exists('./runs2'):
        os.mkdir('./runs2')
    os.system('rm -rf ./runs2/*')
    writer = SummaryWriter('./runs2/' +
                           datetime.now().strftime('%B%d  %H:%M:%S'))

    if not os.path.exists(check_dir):
        os.mkdir(check_dir)

    # Backbone selection; pretrained ImageNet weights in all cases.
    if 'vgg' == opt.i:
        feature = Vgg16(pretrained=True)
    elif 'resnet' == opt.i:
        feature = resnet50(pretrained=True)
    elif 'densenet' == opt.i:
        feature = densenet121(pretrained=True)
    feature.cuda()

    classifier = Classifier(opt.i)
    classifier.cuda()

    # Optionally resume both networks from the epoch-``resume_ep`` checkpoints.
    if resume_ep >= 0:
        feature_param_file = glob.glob('%s/feature-epoch-%d*.pth' %
                                       (check_dir, resume_ep))
        classifier_param_file = glob.glob('%s/classifier-epoch-%d*.pth' %
                                          (check_dir, resume_ep))
        feature.load_state_dict(torch.load(feature_param_file[0]))
        classifier.load_state_dict(torch.load(classifier_param_file[0]))

    train_loader = torch.utils.data.DataLoader(MyClsData(train_dir,
                                                         transform=True,
                                                         crop=True,
                                                         hflip=True,
                                                         vflip=False),
                                               batch_size=bsize,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(MyClsTestData(val_dir,
                                                           transform=True),
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)

    criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(label_weight))
    criterion.cuda()

    # Smaller LR for the pretrained backbone than for the fresh classifier.
    optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    optimizer_feature = torch.optim.Adam(feature.parameters(), lr=1e-4)

    acc = 0.0  # best validation accuracy seen so far
    for it in range(resume_ep + 1, iter_num):
        for ib, (data, lbl) in enumerate(train_loader):
            inputs = Variable(data.float()).cuda()
            lbl = Variable(lbl.long()).cuda()
            feats = feature(inputs)

            output = classifier(feats)
            loss = criterion(output, lbl)

            classifier.zero_grad()
            feature.zero_grad()

            loss.backward()

            optimizer_feature.step()
            optimizer_classifier.step()
            if ib % 20 == 0:
                # image = make_image_grid(inputs.data[:4, :3], mean, std)
                # writer.add_image('Image', torchvision.utils.make_grid(image), ib)
                # loss.data[0] is legacy (PyTorch <= 0.3) scalar indexing,
                # consistent with the rest of this file.
                writer.add_scalar('M_global', loss.data[0], ib)
            print('loss: %.4f (epoch: %d, step: %d), acc: %.4f' %
                  (loss.data[0], it, ib, acc))
            del inputs, lbl, loss, feats
            gc.collect()
        # Checkpoint only when the validation accuracy improves.
        new_acc = eval_acc(feature, classifier, val_loader)
        if new_acc > acc:
            filename = ('%s/classifier-epoch-%d-step-%d.pth' %
                        (check_dir, it, ib))
            torch.save(classifier.state_dict(), filename)
            filename = ('%s/feature-epoch-%d-step-%d.pth' %
                        (check_dir, it, ib))
            torch.save(feature.state_dict(), filename)
            print('save: (epoch: %d, step: %d)' % (it, ib))
            acc = new_acc
예제 #11
0
	





# Lyrics mood classifier: trains a recurrent Classifier over bag-of-words
# lyric vectors with pre-trained embeddings copied in from ``mat``.
# NOTE(review): this fragment appears truncated — the inner loop builds the
# forward pass (``mood_pred``) but no visible loss/backward/optimizer step
# follows; confirm against the full script.
network = Classifier(VOCAB_SIZE,300,50,2)


loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.3)
count=0
in_count = 0
# Fixed input sizes, so cuDNN autotuning is safe to enable.
cudnn.benchmark = True
loss_function = loss_function.cuda()
network.cuda()
# Initialize the embedding layer from the pre-trained matrix ``mat``.
network.embeddings.weight.data.copy_(torch.from_numpy(mat))
for epoch in range(100):
	for lyrics, label in data:

		# Fresh hidden state for every song (no cross-example state).
		hidden = network.init_hidden()
		#import pdb;pdb.set_trace()
		network.zero_grad()
		bag_of_words = Variable(torch.from_numpy(lyrics).long())
		mood = Variable(make_target(label, label_to_index))
		
		# Legacy (PyTorch <= 0.3) style: move the underlying .data to GPU
		# in place rather than reassigning the Variables.
		bag_of_words.data ,mood.data= bag_of_words.data.cuda(), mood.data.cuda()
		hidden[0].data,hidden[1].data = hidden[0].data.cuda(),hidden[1].data.cuda()

		mood_pred, hidden = network(bag_of_words, hidden)
		
예제 #12
0
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    dataloaders.append([train_loader, test_loader])

## model and optimizer instantiations:
net = Classifier(image_size=args.im_size,
                 output_shape=60,
                 tasks=50,
                 layer_size=args.hidden_size,
                 bn_boole=True)
if gpu_boole:
    net = net.cuda()
# optimizer = torch.optim.Adam(net.parameters(), lr = 1e-4)
if args.lr_adj == -1:
    args.lr_adj = args.lr
optimizer = torch.optim.Adam([{
    'params':
    (param for name, param in net.named_parameters() if 'adjx' not in name),
    'lr':
    args.lr,
    'momentum':
    0
}, {
    'params':
    (param for name, param in net.named_parameters() if 'adjx' in name),
    'lr':
    args.lr_adj,
예제 #13
0
class ModelBuilder(object):
    """End-to-end driver for an implicit discourse-relation classifier.

    Wires together data loading, logging, encoder/classifier construction
    (several Arg*Img encoder variants), training with an optional multi-task
    connective-prediction head, and evaluation with F1 reporting.
    Written against a legacy (pre-0.4) PyTorch API: ``Variable``,
    ``volatile=True`` and ``tensor[0]`` scalar indexing appear throughout.
    """

    def __init__(self, use_cuda, conf, model_name):
        # use_cuda: truthy flag; conf: configuration object; model_name
        # selects which encoder variant _build_model constructs.
        self.cuda = use_cuda
        self.conf = conf
        self.model_name = model_name
        self._init_log()
        self._pre_data()
        self._build_model()

    def _pre_data(self):
        """Load the dataset wrapper and the spaCy English pipeline."""
        print('pre data...')
        self.data = Data(self.cuda, self.conf)
        self.spacy = spacy.load('en')
        # print('pre train SenImg pickle...')
        # self.img_pickle_train = self._load_text_img_pickle_all('train')
        # print('pre dev SenImg pickle...')
        # self.img_pickle_dev = self._load_text_img_pickle_all('dev')
        # print('pre test SenImg pickle...')
        # self.img_pickle_test = self._load_text_img_pickle_all('test')

    def _init_log(self):
        """Configure file logging; the filename encodes timestamp, model
        name, data type, and (for binary runs) the positive sense class."""
        # four_or_eleven == 2 means binary classification against one sense.
        if self.conf.four_or_eleven == 2:
            filename = 'logs/train_' + datetime.now().strftime(
                '%B%d-%H_%M_%S'
            ) + '_' + self.model_name + self.conf.type + '_' + self.conf.i2senseclass[
                self.conf.binclass]
        else:
            filename = 'logs/train_' + datetime.now().strftime(
                '%B%d-%H_%M_%S') + '_' + self.model_name + '_' + self.conf.type

        if self.conf.need_elmo:
            filename += '_ELMO'

        logging.basicConfig(filename=filename + '.log',
                            filemode='a',
                            format='%(asctime)s - %(levelname)s: %(message)s',
                            level=logging.DEBUG)

    def _build_model(self):
        """Build encoder/classifier(s), loss, and per-module optimizers.

        Loads pre-trained embeddings (and char/subword tables as configured)
        from the directory matching ``conf.corpus_splitting``, instantiates
        the encoder variant named by ``self.model_name``, and — when
        multi-task training is on — an extra connective classifier.
        """
        print('loading embedding...')
        # corpus_splitting selects the preprocessing scheme / data split.
        if self.conf.corpus_splitting == 1:
            pre = './data/processed/lin/'
        elif self.conf.corpus_splitting == 2:
            pre = './data/processed/ji/'
        elif self.conf.corpus_splitting == 3:
            pre = './data/processed/l/'
        we = torch.load(pre + 'we.pkl')
        char_table = None
        sub_table = None
        if self.conf.need_char or self.conf.need_elmo:
            char_table = torch.load(pre + 'char_table.pkl')
        if self.conf.need_sub:
            sub_table = torch.load(pre + 'sub_table.pkl')
        print('building model...')
        if self.model_name == 'ArgSenImg':
            self.encoder = ArgEncoderSentImg2(self.conf, we, char_table,
                                              sub_table, self.cuda, None,
                                              self.spacy)
        elif self.model_name == 'ArgPhrImg':
            self.encoder = ArgEncoderPhrImg(self.conf, we, char_table,
                                            sub_table, self.cuda, None,
                                            self.spacy)
        elif self.model_name == 'ArgImgSelf':
            self.encoder = ArgEncoderImgSelf(self.conf, we, char_table,
                                             sub_table, self.cuda, None,
                                             self.spacy)
        else:
            self.encoder = ArgEncoder(self.conf, we, char_table, sub_table,
                                      self.cuda)
        self.classifier = Classifier(self.conf.clf_class_num, self.conf)
        if self.conf.is_mttrain:
            # Auxiliary head predicting the (implicit) connective.
            self.conn_classifier = Classifier(self.conf.conn_num, self.conf)
        if self.cuda:
            self.encoder.cuda()
            self.classifier.cuda()
            if self.conf.is_mttrain:
                self.conn_classifier.cuda()

        self.criterion = torch.nn.CrossEntropyLoss()
        # Only optimize parameters that require gradients (frozen embeddings
        # are excluded this way).
        para_filter = lambda model: filter(lambda p: p.requires_grad,
                                           model.parameters())
        self.e_optimizer = torch.optim.Adagrad(
            para_filter(self.encoder),
            self.conf.lr,
            weight_decay=self.conf.l2_penalty)
        self.c_optimizer = torch.optim.Adagrad(
            para_filter(self.classifier),
            self.conf.lr,
            weight_decay=self.conf.l2_penalty)
        if self.conf.is_mttrain:
            self.con_optimizer = torch.optim.Adagrad(
                para_filter(self.conn_classifier),
                self.conf.lr,
                weight_decay=self.conf.l2_penalty)

    def _print_train(self, epoch, time, loss, acc):
        """Print and log an end-of-epoch training summary line."""
        print('-' * 80)
        print(
            '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |'
            .format(epoch, time, loss, acc * 100))
        print('-' * 80)
        logging.debug('-' * 80)
        logging.debug(
            '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |'
            .format(epoch, time, loss, acc * 100))
        logging.debug('-' * 80)

    def _print_eval(self, task, loss, acc, f1):
        """Print and log evaluation metrics for the given task split."""
        print('| ' + task +
              ' loss {:10.5f} | acc {:5.2f}% | f1 {:5.2f}%'.format(
                  loss, acc * 100, f1 * 100))
        print('-' * 80)
        logging.debug('| ' + task +
                      ' loss {:10.5f} | acc {:5.2f}% | f1 {:5.2f}%'.format(
                          loss, acc * 100, f1 * 100))
        logging.debug('-' * 80)

    def _save_model(self, model, filename):
        """Persist a module's state_dict under ./weights/."""
        torch.save(model.state_dict(), './weights/' + filename)

    def _load_model(self, model, filename):
        """Restore a module's state_dict from ./weights/."""
        model.load_state_dict(torch.load('./weights/' + filename))

    def _train_one(self):
        """Run one training epoch; return (mean loss, accuracy).

        NOTE(review): each batch is wrapped in a broad ``except Exception``
        that prints and skips — training errors are silently tolerated.
        ``total_loss[0]`` / ``correct_n[0]`` rely on legacy 1-element-tensor
        indexing (PyTorch <= 0.3).
        """
        self.encoder.train()
        self.classifier.train()
        if self.conf.is_mttrain:
            self.conn_classifier.train()
        total_loss = 0
        correct_n = 0
        train_size = self.data.train_size
        for i, (a1, a2, sense, conn, arg1_sen,
                arg2_sen) in enumerate(self.data.train_loader):
            try:
                start_time = time.time()
                # Binary mode: relabel the chosen sense as 1, all others 0.
                if self.conf.four_or_eleven == 2:
                    mask1 = (sense == self.conf.binclass)
                    mask2 = (sense != self.conf.binclass)
                    sense[mask1] = 1
                    sense[mask2] = 0
                if self.cuda:
                    a1, a2, sense, conn = a1.cuda(), a2.cuda(), sense.cuda(
                    ), conn.cuda()
                a1, a2, sense, conn = Variable(a1), Variable(a2), Variable(
                    sense), Variable(conn)
                # Image-augmented encoders additionally consume the text/img
                # pickles loaded per batch index.
                if self.model_name in [
                        'ArgImg', 'ArgSenImg', 'ArgPhrImg', 'ArgImgSelf'
                ]:
                    self._load_text_img_pickle_index(i)
                    # img_pickle = self.img_pickle_train[i]
                    repr = self.encoder(a1, a2, arg1_sen, arg2_sen,
                                        self.text_pkl, self.img_pkl,
                                        self.phrase_text_pkl,
                                        self.phrase_img_pkl, i, 'train')
                else:
                    repr = self.encoder(a1, a2)
                output = self.classifier(repr)
                _, output_sense = torch.max(output, 1)
                assert output_sense.size() == sense.size()
                tmp = (output_sense == sense).long()
                correct_n += torch.sum(tmp).data
                loss = self.criterion(output, sense)

                # Multi-task: add weighted connective-prediction loss.
                if self.conf.is_mttrain:
                    conn_output = self.conn_classifier(repr)
                    loss2 = self.criterion(conn_output, conn)
                    loss = loss + loss2 * self.conf.lambda1

                self.e_optimizer.zero_grad()
                self.c_optimizer.zero_grad()
                if self.conf.is_mttrain:
                    self.con_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm(self.encoder.parameters(),
                                              self.conf.grad_clip)
                torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                              self.conf.grad_clip)
                if self.conf.is_mttrain:
                    torch.nn.utils.clip_grad_norm(
                        self.conn_classifier.parameters(), self.conf.grad_clip)
                self.e_optimizer.step()
                self.c_optimizer.step()
                if self.conf.is_mttrain:
                    self.con_optimizer.step()

                # Weight the batch loss by its size for a proper epoch mean.
                total_loss += loss.data * sense.size(0)
                if self.model_name in ['ArgImg', 'ArgSenImg', 'ArgPhrImg']:
                    print('==================' + str(i) + '==================')
                    print('total_loss:' + str(total_loss[0] / (len(arg1_sen) *
                                                               (i + 1))) +
                          ' acc:' + str(correct_n[0].float() /
                                        (len(arg1_sen) * (i + 1))) + ' time:' +
                          str(time.time() - start_time))
                    logging.debug('==================' + str(i) +
                                  '==================')
                    logging.debug('total_loss:' + str(total_loss[0] /
                                                      (len(arg1_sen) *
                                                       (i + 1))) + ' acc:' +
                                  str(correct_n[0].float() /
                                      (len(arg1_sen) * (i + 1))) + ' time:' +
                                  str(time.time() - start_time))
            except Exception as e:
                # Best-effort training: log the error and skip the batch.
                print(e)
                logging.debug(e)
                continue

        return total_loss[0] / train_size, correct_n[0].float() / train_size

    def _train(self, pre):
        """Inner training loop: one epoch of training plus dev/test
        evaluation per epoch, all logged to TensorBoard."""
        for epoch in range(self.conf.epochs):
            start_time = time.time()
            loss, acc = self._train_one()
            self._print_train(epoch, time.time() - start_time, loss, acc)
            self.logwriter.add_scalar('loss/train_loss', loss, epoch)
            self.logwriter.add_scalar('acc/train_acc', acc * 100, epoch)

            dev_loss, dev_acc, dev_f1 = self._eval('dev')
            self._print_eval('dev', dev_loss, dev_acc, dev_f1)
            self.logwriter.add_scalar('loss/dev_loss', dev_loss, epoch)
            self.logwriter.add_scalar('acc/dev_acc', dev_acc * 100, epoch)
            self.logwriter.add_scalar('f1/dev_f1', dev_f1 * 100, epoch)

            test_loss, test_acc, test_f1 = self._eval('test')
            self._print_eval('test', test_loss, test_acc, test_f1)
            self.logwriter.add_scalar('loss/test_loss', test_loss, epoch)
            self.logwriter.add_scalar('acc/test_acc', test_acc * 100, epoch)
            self.logwriter.add_scalar('f1/test_f1', test_f1 * 100, epoch)

    def train(self, pre):
        """Public entry point: train for ``conf.epochs`` epochs, then save
        encoder and classifier weights under the ``pre`` prefix."""
        print('start training')
        logging.debug('start training')
        self.logwriter = SummaryWriter(self.conf.logdir)
        self._train(pre)
        self._save_model(self.encoder, pre + '_eparams.pkl')
        self._save_model(self.classifier, pre + '_cparams.pkl')
        print('training done')
        logging.debug('training done')

    def _eval(self, task):
        """Evaluate on 'dev' or 'test'; return (mean loss, accuracy, f1).

        Examples may carry a second admissible sense (``sense2``, -1 when
        absent); a prediction matching either sense counts as correct.
        NOTE(review): like _train_one, each batch is wrapped in a broad
        ``except Exception`` that logs and skips.
        """
        self.encoder.eval()
        self.classifier.eval()
        total_loss = 0
        correct_n = 0
        if task == 'dev':
            data = self.data.dev_loader
            n = self.data.dev_size
        elif task == 'test':
            data = self.data.test_loader
            n = self.data.test_size
        else:
            raise Exception('wrong eval task')
        output_list = []
        gold_list = []
        for i, (a1, a2, sense1, sense2, arg1_sen, arg2_sen) in enumerate(data):
            try:
                # Binary mode: remap both gold senses, preserving the -1
                # "no second sense" marker in sense2.
                if self.conf.four_or_eleven == 2:
                    mask1 = (sense1 == self.conf.binclass)
                    mask2 = (sense1 != self.conf.binclass)
                    sense1[mask1] = 1
                    sense1[mask2] = 0
                    mask0 = (sense2 == -1)
                    mask1 = (sense2 == self.conf.binclass)
                    mask2 = (sense2 != self.conf.binclass)
                    sense2[mask1] = 1
                    sense2[mask2] = 0
                    sense2[mask0] = -1
                if self.cuda:
                    a1, a2, sense1, sense2 = a1.cuda(), a2.cuda(), sense1.cuda(
                    ), sense2.cuda()
                # volatile=True: legacy no-grad inference mode (PyTorch <= 0.3).
                a1 = Variable(a1, volatile=True)
                a2 = Variable(a2, volatile=True)
                sense1 = Variable(sense1, volatile=True)
                sense2 = Variable(sense2, volatile=True)

                if self.model_name in [
                        'ArgImg', 'ArgSenImg', 'ArgPhrImg', 'ArgImgSelf'
                ]:
                    # self._load_text_img_pickle_index(i)
                    # if task == 'dev':
                    #     img_pickle = self.img_pickle_dev[i]
                    # else:
                    #     img_pickle = self.img_pickle_test[i]
                    self._load_text_img_pickle_index(i)
                    # img_pickle = self.img_pkl
                    output = self.classifier(
                        self.encoder(a1,
                                     a2,
                                     arg1_sen,
                                     arg2_sen,
                                     self.text_pkl,
                                     self.img_pkl,
                                     self.phrase_text_pkl,
                                     self.phrase_img_pkl,
                                     i,
                                     task=task))
                else:
                    output = self.classifier(self.encoder(a1, a2))
                _, output_sense = torch.max(output, 1)
                assert output_sense.size() == sense1.size()
                # Where the prediction matches the second sense, score
                # against that instead (note: mutates sense1 via aliasing).
                gold_sense = sense1
                mask = (output_sense == sense2)
                gold_sense[mask] = sense2[mask]
                tmp = (output_sense == gold_sense).long()
                correct_n += torch.sum(tmp).data

                output_list.append(output_sense)
                gold_list.append(gold_sense)

                loss = self.criterion(output, gold_sense)
                total_loss += loss.data * gold_sense.size(0)

                # F1 is recomputed over all batches seen so far; the final
                # iteration's value is returned.
                output_s = torch.cat(output_list)
                gold_s = torch.cat(gold_list)
                if self.conf.four_or_eleven == 2:
                    f1 = f1_score(gold_s.cpu().data.numpy(),
                                  output_s.cpu().data.numpy(),
                                  average='binary')
                else:
                    f1 = f1_score(gold_s.cpu().data.numpy(),
                                  output_s.cpu().data.numpy(),
                                  average='macro')
            except Exception as e:
                print(e)
                logging.debug(e)
                continue

        return total_loss[0] / n, correct_n[0].float() / n, f1

    def eval(self, pre):
        """Load saved weights under the ``pre`` prefix and report test metrics."""
        print('evaluating...')
        logging.debug('evaluating...')
        self._load_model(self.encoder, pre + '_eparams.pkl')
        self._load_model(self.classifier, pre + '_cparams.pkl')
        test_loss, test_acc, f1 = self._eval('test')
        self._print_eval('test', test_loss, test_acc, f1)

    def _load_text_img_pickle_all(self, task='train'):
        """Pre-load image pickles for every batch of the given split.

        Currently unused (callers in _pre_data are commented out).
        """
        img_pickle = []
        if task == 'dev':
            data = self.data.dev_loader
            n = self.data.dev_size
        elif task == 'test':
            data = self.data.test_loader
            n = self.data.test_size
        else:
            data = self.data.train_loader
            # NOTE(review): assigns the loader, not a size — presumably meant
            # ``self.data.train_size``; harmless here since ``n`` is unused.
            n = self.data.train_loader

        for i, (a1, a2, sense1, sense2, arg1_sen, arg2_sen) in enumerate(data):
            self._load_text_img_pickle_index(i, task)
            img_pickle.append(self.img_pkl)

        return np.array(img_pickle)

    def _load_text_img_pickle_index(self, index, task='train'):
        """Load sentence/phrase text+image pickles for batch ``index``.

        Populates self.text_pkl, self.img_pkl, self.phrase_text_pkl and
        self.phrase_img_pkl (empty lists when files are missing or the
        model variant does not use them). Images are stored as flattened
        arrays and reshaped to (N, 3, 256, 256) here.
        """
        root_dir = '/home/wangjian/projects/RNNImageIDRR/data/text_img'
        # Train files have no task infix; dev/test files do.
        if task != 'train':
            text_pkl_path = root_dir + '/text_' + task + '_' + str(
                index) + '.pkl'
            img_pkl_path = root_dir + '/img_' + task + '_' + str(
                index) + '.pkl'
            phrase_text_pkl_path = root_dir + '/phrase_text_' + task + '_' + str(
                index) + '.pkl'
            phrase_img_pkl_path = root_dir + '/phrase_img_' + task + '_' + str(
                index) + '.pkl'
        else:
            text_pkl_path = root_dir + '/text_' + str(index) + '.pkl'
            img_pkl_path = root_dir + '/img_' + str(index) + '.pkl'
            phrase_text_pkl_path = root_dir + '/phrase_text_' + str(
                index) + '.pkl'
            phrase_img_pkl_path = root_dir + '/phrase_img_' + str(
                index) + '.pkl'
        self.text_pkl = []
        self.img_pkl = []
        self.phrase_text_pkl = []
        self.phrase_img_pkl = []
        if self.model_name in [
                'ArgImg', 'ArgSenImg', 'ArgPhrImg', 'ArgImgSelf'
        ]:
            # Sentence-level images for these variants.
            if self.model_name in ['ArgImg', 'ArgSenImg', 'ArgImgSelf']:
                if os.path.exists(text_pkl_path) and os.path.exists(
                        img_pkl_path):
                    # print(img_pkl_path)
                    # logging.debug(img_pkl_path)
                    # with open(text_pkl_path, 'rb') as f:
                    #     try:
                    #         while True:
                    #             self.text_pkl.append(pickle.load(f))
                    #     except:
                    #         pass
                    with h5py.File(img_pkl_path) as f:
                        img = f['img_features'][:]
                        img = img.reshape((len(img), 3, 256, 256))
                        self.img_pkl.extend(img)

            # Phrase-level text + images for these variants; the text pickle
            # holds multiple sequentially-pickled objects, read until EOF.
            if self.model_name in ['ArgImg', 'ArgPhrImg', 'ArgImgSelf']:
                if os.path.exists(phrase_text_pkl_path) and os.path.exists(
                        phrase_img_pkl_path):
                    with open(phrase_text_pkl_path, 'rb') as f:
                        try:
                            while True:
                                self.phrase_text_pkl.append(pickle.load(f))
                        except:
                            pass
                    with h5py.File(phrase_img_pkl_path) as f:
                        img = f['img_features'][:]
                        img = img.reshape((len(img), 3, 256, 256))
                        self.phrase_img_pkl.extend(img)
예제 #14
0
def main():
    """Entry point: parse CLI arguments, merge them into the YACS config,
    set up wandb logging, build train/valid loaders and the model, then
    train (optionally resuming from a checkpoint and fine-tuning at the
    end with a reduced learning rate).
    """
    ####################argument parser#################################
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path',
                        type=str,
                        default="",
                        required=True,
                        help='Location of current config file')
    parser.add_argument('--dataset_csvpath',
                        type=str,
                        required=True,
                        default="./",
                        help='Location to data csv file')

    # SECURITY FIX: the original shipped a hard-coded wandb API key as the
    # default value, committing a credential to source control. Read it
    # from the environment instead; explicit --wandbkey still wins.
    parser.add_argument("--wandbkey",
                        type=str,
                        default=os.environ.get('WANDB_API_KEY', ''),
                        help='Wandb project key (default: $WANDB_API_KEY)')
    parser.add_argument("--wandbproject",
                        type=str,
                        required=True,
                        default='',
                        help='wandb project name')
    parser.add_argument("--wandbexperiment",
                        type=str,
                        required=True,
                        default='',
                        help='wandb experiment name')

    parser.add_argument("--ckpt_save_dir",
                        type=str,
                        default='./ckpts',
                        help='path to save checkpoints')

    parser.add_argument(
        "--resume_checkpoint_path",
        type=str,
        default='',
        help='If you want to resume training enter path to checkpoint')

    #####################read config file###############################
    args = parser.parse_args()
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_path)
    # Record the CLI-supplied paths inside the (then frozen) config so the
    # run is fully described by the config object logged to wandb.
    cfg.merge_from_list([
        'train.config_path', args.config_path, 'dataset.csvpath',
        args.dataset_csvpath, 'train.ckpt_save_dir', args.ckpt_save_dir
    ])
    cfg.freeze()
    print(cfg)

    ####### Wandb
    os.makedirs(args.ckpt_save_dir, exist_ok=True)
    if args.wandbkey:
        # NOTE(review): passing the key on a shell command line exposes it
        # in the process list; consider `wandb.login(key=...)` instead.
        os.system('wandb login {}'.format(args.wandbkey))
    wandb.init(name=args.wandbexperiment,
               project=args.wandbproject,
               config=cfg)
    wandb.save(args.config_path)  # Save configuration file on wandb

    # Validation runs only when a valid.csv exists next to the train data.
    validate = os.path.exists(os.path.join(args.dataset_csvpath, 'valid.csv'))
    train_object = classDataset(cfg, 'train')
    train_loader = DataLoader(train_object,
                              batch_size=cfg.dataset.batch_size_pergpu *
                              len(cfg.train.gpus),
                              shuffle=cfg.dataset.shuffle,
                              num_workers=cfg.dataset.num_workers)

    valid_loader = None
    if validate:
        valid_object = classDataset(cfg, 'valid')
        valid_loader = DataLoader(valid_object,
                                  batch_size=cfg.dataset.batch_size_pergpu *
                                  len(cfg.train.gpus),
                                  shuffle=cfg.dataset.shuffle,
                                  num_workers=cfg.dataset.num_workers)

    model = Classifier(cfg)
    if torch.cuda.is_available():
        model.cuda()

    criterion = loss[cfg.Loss.val](cfg)
    optimizer = optimizers[cfg.optimizer.val](model.parameters(), cfg)
    start_epoch = 1

    if args.resume_checkpoint_path != '':
        old_dict = torch.load(args.resume_checkpoint_path)
        model.load_state_dict(old_dict['state_dict'])
        optimizer.load_state_dict(old_dict['optim_dict'])
        # NOTE(review): this re-runs the epoch stored in the checkpoint;
        # if checkpoints are saved *after* an epoch completes this should
        # be old_dict['epoch'] + 1 — confirm against the saving code.
        start_epoch = old_dict['epoch']

    epoch = start_epoch  # keep `epoch` bound even if the loop never runs
    for epoch in range(start_epoch, cfg.train.n_epochs):
        train_epoch(epoch, train_loader, model, criterion, optimizer, cfg)
        # BUG FIX: the original `if (epoch % cfg.valid.frequency):`
        # validated on every epoch EXCEPT multiples of the frequency.
        if validate and epoch % cfg.valid.frequency == 0:
            # BUG FIX: the original passed `model.valid_loader` (a missing
            # comma) — the model has no such attribute. Assumes signature
            # validate_model(epoch, model, loader, criterion) — TODO
            # confirm against validate_model's definition.
            validate_model(epoch, model, valid_loader, criterion)

    # Fine tune model with very low learning rate Add example
    if cfg.train.tune.val:
        set_lr(optimizer, lr=cfg.optimizer.lr / cfg.train.tune.lr_factor)
        train_epoch(epoch + 1, train_loader, model, criterion, optimizer, cfg)
예제 #15
0
# --- Standalone classification-evaluation script (top-level statements).
# Relies on a pre-parsed argparse namespace `opt` and models/datasets
# imported earlier in the file.
feature_param_file = opt.feat
class_param_file = opt.cls
bsize = opt.b

# models
# Select the feature-extraction backbone by name. NOTE(review): if opt.i
# is none of these three values, `feature` is unbound and the .cuda()
# call below raises NameError — consider an explicit error for bad input.
if 'vgg' == opt.i:
    feature = Vgg16()
elif 'resnet' == opt.i:
    feature = resnet50()
elif 'densenet' == opt.i:
    feature = densenet121()
feature.cuda()
feature.load_state_dict(torch.load(feature_param_file))

# Classifier head matching the chosen backbone (opt.i selects its shape).
classifier = Classifier(opt.i)
classifier.cuda()
classifier.load_state_dict(torch.load(class_param_file))

loader = torch.utils.data.DataLoader(
    MyClsTestData(test_dir, transform=True),
    batch_size=bsize, shuffle=True, num_workers=4, pin_memory=True)

# Accuracy accumulators over the test set.
it = 0.0  # running count of samples seen
num_correct = 0
for ib, (data, lbl) in enumerate(loader):
    inputs = Variable(data.float()).cuda()
    lbl = lbl.cuda()
    it+=lbl.size(0)
    feats = feature(inputs)
    output = classifier(feats)
    _, pred_lbl = torch.max(output, 1)
    # NOTE(review): `num_correct` is never incremented and `pred_lbl` is
    # unused — the accuracy accumulation (and final print) appears to be
    # truncated from this fragment; confirm against the original script.
예제 #16
0
# Load pretrained weights for the per-source-domain classifiers.
# NOTE(review): s1_classifier is presumably loaded the same way just
# above this fragment. State dicts are loaded BEFORE wrapping in
# nn.DataParallel, so the checkpoint keys carry no 'module.' prefix.
s2_classifier.load_state_dict(
    torch.load(
        osp.join(
            MAIN_DIR,
            "MSDA/A_W_2_D_Open/bvlc_A_W_2_D/pretrain/office-home/bvlc_s2_cls.pth"
        )))
s3_classifier.load_state_dict(
    torch.load(
        osp.join(
            MAIN_DIR,
            "MSDA/A_W_2_D_Open/bvlc_A_W_2_D/pretrain/office-home/bvlc_s3_cls.pth"
        )))
# Wrap each classifier for multi-GPU execution, then move to the GPU.
s1_classifier = nn.DataParallel(s1_classifier)
s2_classifier = nn.DataParallel(s2_classifier)
s3_classifier = nn.DataParallel(s3_classifier)
s1_classifier = s1_classifier.cuda()
s2_classifier = s2_classifier.cuda()
s3_classifier = s3_classifier.cuda()

# Same treatment for the three source-vs-target domain discriminators.
s1_t_discriminator = nn.DataParallel(s1_t_discriminator)
s1_t_discriminator = s1_t_discriminator.cuda()
s2_t_discriminator = nn.DataParallel(s2_t_discriminator)
s2_t_discriminator = s2_t_discriminator.cuda()
s3_t_discriminator = nn.DataParallel(s3_t_discriminator)
s3_t_discriminator = s3_t_discriminator.cuda()


def print_log(step, epoch, epoches, lr, l1, l2, l3, l4, l5, l6, l7, l8, l9,
              l10, l11, l12, flag, ploter, count):
    logger.info("Step [%d/%d] Epoch [%d/%d] lr: %f, s1_cls_loss: %.4f, s2_cls_loss: %.4f,s3_cls_loss: %.4f, s1_t_dis_loss: %.4f, " \
          "s2_t_dis_loss: %.4f, s3_t_dis_loss: %.4f, s1_t_confusion_loss_s1: %.4f, s1_t_confusion_loss_t: %.4f, " \
예제 #17
0
File: train.py  Project: Tommy-Xu/CHIM
# --- Model construction and data unpacking (top-level script).
# `word_dict`, `train_data`, `dev_data`, `word_vectors` and the various
# hyperparameters are defined earlier in the file.
word_size = len(word_dict)

# NOTE(review): `basis` is passed for both chunk_ratio and basis —
# confirm this aliasing is intentional rather than a copy/paste slip.
model = Model(word_size,
              label_size,
              category_sizes,
              word_dim,
              embed_dim,
              hidden_dim,
              category_dim,
              inject_type,
              inject_locs,
              chunk_ratio=basis,
              basis=basis)
# Initialize the embedding table from pretrained word vectors; no_grad
# keeps this in-place copy out of the autograd graph.
with torch.no_grad():
    model.embedding.weight.set_(torch.from_numpy(word_vectors).float())
model.cuda()
optimizer = torch.optim.Adadelta(model.parameters())
# Resume-from-checkpoint logic, currently disabled.
#if os.path.exists(model_file):
#    best_point = torch.load(model_file)
#    model.load_state_dict(best_point['state_dict'])
#    optimizer.load_state_dict(best_point['optimizer'])
#    best_dev_acc = best_point['dev_acc']
print("Total Parameters:", sum(p.numel() for p in model.parameters()))

# Unpack inputs (x), categorical features (c) and labels (y) for the
# train and dev splits.
x_train = train_data['x']
c_train = train_data['c']
y_train = train_data['y']
x_dev = dev_data['x']
c_dev = dev_data['c']
y_dev = dev_data['y']
예제 #18
0
def train(task,
          phase,
          num_class,
          num_words,
          logs_dir,
          models_dir,
          datafolds,
          seed,
          num_folds=10,
          glove=None,
          epochs=20,
          batch_size=25,
          input_size=100,
          hidden_size=50,
          lr=0.008,
          lr_milestones=None,
          weight_decay=1e-4,
          log_iteration_interval=500,
          use_gpu=False):
    """Train a tree Classifier on folds 0..num_folds-2 of `datafolds`,
    evaluating on the last fold and checkpointing after every epoch.

    Gradients are accumulated one sample at a time and applied every
    `batch_size` samples. A running loss is appended to a train log every
    `log_iteration_interval` samples; per-epoch accuracy goes to a
    separate eval log. Checkpoints land in `models_dir`.
    """
    # The config string doubles as the run identifier for logs and
    # checkpoints; evaluate() later regex-parses input/hidden sizes out
    # of it, so its format must stay stable.
    config_string = '{}_{}_batchsize{}_input{}_hidden{}_lr{}{}_wc{}{}_seed{}'.format(
        task, phase, batch_size, input_size, hidden_size, lr,
        '' if not lr_milestones
        else '_ms' + ','.join([str(i) for i in lr_milestones]),
        weight_decay, '_glove' if glove is not None else '', seed)
    log_train_path = os.path.join(logs_dir,
                                  'train_{}.txt'.format(config_string))
    log_eval_path = os.path.join(logs_dir, 'eval_{}.txt'.format(config_string))
    print('[INFO] {}'.format(config_string))
    classifier = Classifier(input_size, hidden_size, num_class, num_words,
                            glove, use_gpu)
    criterion = nn.CrossEntropyLoss()
    if use_gpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()
    # Optimize only trainable parameters (e.g. frozen glove embeddings
    # are excluded by requires_grad).
    optimizer = optim.Adam(
        [p for p in classifier.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=weight_decay)
    dataset_train = list()
    for i in range(num_folds - 1):
        dataset_train += datafolds[i]
    dataset_eval = datafolds[-1]
    scheduler = None if not lr_milestones else optim.lr_scheduler.MultiStepLR(
        optimizer, lr_milestones, gamma=0.5)
    # train
    for epoch in range(epochs):
        if scheduler is not None:
            # NOTE(review): stepping at the START of the epoch drops the
            # LR one epoch earlier than the milestone numbers suggest
            # (pre-1.1 PyTorch convention) — confirm this is intended.
            scheduler.step()
        random.shuffle(dataset_train)
        optimizer.zero_grad()
        log_loss = 0
        for iteration in range(1, len(dataset_train) + 1):
            tree_root, label = dataset_train[iteration - 1]
            output = classifier(tree_root)
            # BUG FIX: the original constructed `target` twice on the GPU
            # path; build it once and move it to the device as needed.
            target = Variable(torch.LongTensor([label]), requires_grad=False)
            if use_gpu:
                target = target.cuda()
            loss = criterion(output, target)
            loss.backward()
            if iteration % batch_size == 0:
                optimizer.step()
                optimizer.zero_grad()
            # log
            # BUG FIX: `loss.data[0]` was removed from PyTorch (0-dim
            # tensors raise on [0] indexing since 0.5); `loss.item()` is
            # the supported way to read the scalar loss.
            log_loss += loss.item() / log_iteration_interval
            if iteration % log_iteration_interval == 0:
                add_log(log_train_path, '{} {} {}'.format(
                    time.ctime(), iteration, log_loss))
                log_loss = 0
        # BUG FIX: gradients accumulated for a trailing partial batch were
        # silently discarded by the zero_grad() at the next epoch's start;
        # apply them before evaluating.
        if len(dataset_train) % batch_size != 0:
            optimizer.step()
            optimizer.zero_grad()
        # evaluate
        correct, total = classifier.evalute_dataset(dataset_eval)
        add_log(log_eval_path, '{} / {} = {:.3f}'.format(
            correct, total,
            float(correct) / total))
        # save checkpoint
        checkpoint = {
            'model': classifier.state_dict(),
            # NOTE(review): this pickles the whole optimizer object rather
            # than optimizer.state_dict(); kept as-is so existing
            # checkpoint consumers still work, but state_dict() is the
            # conventional, portable form.
            'optimizer': optimizer,
            'epoch': epoch,
            'config_string': config_string
        }
        checkpoint_path = os.path.join(models_dir, '{}_epoch{}.pth'.format(
            config_string, epoch))
        torch.save(checkpoint, checkpoint_path)