Code Example #1
import time

import torch

# NOTE: train_epoch and eval_epoch are project-local helpers assumed to return
# (loss, auc) and (loss, auc, proba) respectively.
def train(name, model, training_data, validation_data, crit, optimizer, scheduler, opt):

    valid_aucs = [0.]
    for epoch_i in range(opt.epoch):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_auc = train_epoch(model, training_data, crit, optimizer)
        print('  - (Training)   loss: {loss: 8.5f}, auc: {auc:3.3f} %, '\
              'elapse: {elapse:3.3f} min'.format(
                  loss=train_loss, auc=100*train_auc,
                  elapse=(time.time()-start)/60))

        start = time.time()
        valid_loss, valid_auc, valid_proba = eval_epoch(model, validation_data, crit)

        print('  - (Validation) loss: {loss: 8.5f}, auc: {auc:3.3f} %, '\
                'elapse: {elapse:3.3f} min'.format(
                    loss=valid_loss, auc=100*valid_auc,
                    elapse=(time.time()-start)/60))
        
        best_auc = max(valid_aucs)  # best validation AUC seen so far
        valid_aucs += [valid_auc]
        scheduler.step(valid_loss)

        model_state_dict = model.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'settings': opt,
            'epoch': epoch_i,
            'auc': valid_auc}

        model_name = name + '.chkpt'
        if valid_auc >= best_auc:
            print('new best auc:', valid_auc)
            best_proba = valid_proba
            best_model = model
            if opt.save_model:
                torch.save(checkpoint, 'models/'+model_name)
                print('    - [Info] The checkpoint file has been updated.')

        if opt.log:
            directory = 'predictions/' + opt.name
            log_train_file = directory + '/train.log'
            log_valid_file = directory + '/valid.log'
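            # NOTE: assumes the predictions/<opt.name> directory already exists;
            # open(..., 'a') will not create missing parent directories.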

            with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                log_tf.write('{fold},{epoch},{loss: 8.5f},{auc:3.3f}\n'.format(
                    fold=name, epoch=epoch_i, loss=train_loss, auc=100*train_auc))
                log_vf.write('{fold},{epoch},{loss: 8.5f},{auc:3.3f}\n'.format(
                    fold=name, epoch=epoch_i, loss=valid_loss, auc=100*valid_auc))

    return best_model, best_proba
Code Example #2
File: trainQGenBelief.py  Project: timbmg/belief
def main(args):

    print(args)

    ts = datetime.datetime.now().timestamp()

    logger = SummaryWriter(
        os.path.join('exp/qgenbelief/', '{}_{}'.format(args.exp_name, ts)))
    logger.add_text('exp_name', args.exp_name)
    logger.add_text('args', str(args))

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab = Vocab(os.path.join(args.data_dir, 'vocab.csv'), args.min_occ)
    category_vocab = CategoryVocab(
        os.path.join(args.data_dir, 'categories.csv'))

    data_loader = OrderedDict()
    splits = ['train', 'valid']

    # Dataset options
    ds_kwargs = dict()

    if args.mrcnn:
        ds_kwargs['mrcnn_objects'] = True
        ds_kwargs['mrcnn_settings'] = \
            {'filter_category': True, 'skip_below_05': True}

    load_vgg_features = args.visual_representation == 'vgg' \
        and args.visual_embedding_dim > 0
    if load_vgg_features:
        ds_kwargs['load_vgg_features'] = True

    load_resnet_features = args.visual_representation == 'resnet-mlb' \
        and args.visual_embedding_dim > 0
    if load_resnet_features:
        ds_kwargs['load_resnet_features'] = True

    for split in splits:
        file = os.path.join(args.data_dir, 'guesswhat.' + split + '.jsonl.gz')
        data_loader[split] = DataLoader(
            dataset=QuestionerDataset(split,
                                      file,
                                      vocab,
                                      category_vocab,
                                      True,
                                      cumulative_dialogue=True,
                                      **ds_kwargs),
            batch_size=args.batch_size,
            shuffle=split == 'train',
            #collate_fn=QuestionerDataset.get_collate_fn(device),
            collate_fn=QuestionerDataset.collate_fn,
            num_workers=args.num_workers)

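    # Load the pre-trained Guesser and freeze it unless it is trained jointly.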
    guesser = Guesser.load(device, file=args.guesser_file)
    if not args.train_guesser_setting:
        for p in guesser.parameters():
            p.requires_grad = False

    num_additional_features = 0 if args.no_belief_state_input \
        else args.category_embedding_dim
    qgen = QGen(len(vocab), args.word_embedding_dim, args.num_visual_features,
                args.visual_embedding_dim, args.hidden_size,
                num_additional_features).to(device)

    model = QGenBelief(qgen, guesser, args.category_embedding_dim,
                       args.object_embedding_setting,
                       args.object_probs_setting, args.train_guesser_setting,
                       args.visual_representation, args.num_visual_features,
                       args.visual_query).to(device)

    print(model)
    logger.add_text('model', str(model))

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    if args.train_guesser_setting:
        optimizer_guesser = torch.optim.Adam(guesser.parameters(),
                                             lr=args.learning_rate_guesser)
        optimizer = [optimizer, optimizer_guesser]

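    # Maps the model's forward() keyword arguments to the corresponding batch
    # keys (presumably consumed by eval_epoch when unpacking each batch).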
    forward_kwargs_mapping = {
        'dialogue': 'source_dialogue',
        'dialogue_lengths': 'dialogue_lengths',
        'cumulative_dialogue': 'cumulative_dialogue',
        'cumulative_lengths': 'cumulative_lengths',
        'num_questions': 'num_questions',
        'object_categories': 'object_categories',
        'object_bboxes': 'object_bboxes',
        'num_objects': 'num_objects',
        'question_lengths': 'question_lengths'
    }
    if args.mrcnn:
        forward_kwargs_mapping['guesser_visual_features'] = \
            'mrcnn_visual_features'
    if load_vgg_features:
        forward_kwargs_mapping['visual_features'] = 'vgg_features'
    if load_resnet_features:
        forward_kwargs_mapping['resnet_features'] = 'resnet_features'

    target_kwarg = 'target_dialogue'

    best_val_loss = 1e9

    for epoch in range(args.epochs):
        train_loss, train_acc = eval_epoch(model, data_loader['train'],
                                           forward_kwargs_mapping,
                                           target_kwarg, loss_fn, optimizer)

        valid_loss, valid_acc = eval_epoch(model, data_loader['valid'],
                                           forward_kwargs_mapping,
                                           target_kwarg, loss_fn)

        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            model.save(
                os.path.join('bin',
                             'qgenbelief_{}_{}.pt'.format(args.exp_name, ts)))

        logger.add_scalar('train_loss', train_loss, epoch)
        logger.add_scalar('valid_loss', valid_loss, epoch)

        print(("Epoch {:2d}/{:2d} Train Loss {:07.4f} Vaild Loss {:07.4f}"
               ).format(epoch, args.epochs, train_loss, valid_loss))
Code Example #3
def main(args):
    print(args)

    if not args.eval:
        ts = datetime.datetime.now().timestamp()
        logger = SummaryWriter(
            os.path.join('exp/guesser/',
                         '{}_{}_{}'.format(args.exp_name, args.setting, ts)))
        logger.add_text('exp_name', args.exp_name)
        logger.add_text('args', str(args))

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab = Vocab(os.path.join(args.data_dir, 'vocab.csv'), args.min_occ)
    category_vocab = CategoryVocab(
        os.path.join(args.data_dir, 'categories.csv'))

    data_loader = OrderedDict()
    if not args.eval:
        splits = ['train', 'valid']
    else:
        splits = ['valid', 'test']

    ds_kwargs = dict()
    if args.setting == 'mrcnn':
        ds_kwargs['mrcnn_objects'] = True
        ds_kwargs['mrcnn_settings'] = \
            {'filter_category': True, 'skip_below_05': True}
    for split in splits:
        file = os.path.join(args.data_dir,
                            'guesswhat.{}.jsonl.gz'.format(split))
        data_loader[split] = DataLoader(
            dataset=QuestionerDataset(split, file, vocab, category_vocab,
                                      not args.eval, **ds_kwargs),
            batch_size=args.batch_size,
            shuffle=split == 'train',
            #collate_fn=QuestionerDataset.get_collate_fn(device))
            collate_fn=QuestionerDataset.collate_fn)
        # the logger only exists in training mode, so skip these stats when evaluating
        if args.setting == 'mrcnn' and not args.eval:
            logger.add_text("{}_num_datapoints".format(split),
                            str(len(data_loader[split].dataset)))
            logger.add_text("{}_skipped_datapoints".format(split),
                            str(data_loader[split].dataset.skipped_datapoints))

    if not args.eval:
        model = Guesser(len(vocab), args.word_embedding_dim,
                        len(category_vocab), args.category_embedding_dim,
                        args.hidden_size, args.mlp_hidden,
                        args.setting).to(device)
    else:
        model = Guesser.load(device, file=args.bin)
    print(model)

    # class weights come from the train split, which is only loaded when training
    class_weight = torch.Tensor(data_loader['train'].dataset.category_weights)\
        .to(device) if args.weight_loss and not args.eval else None
    loss_fn = torch.nn.CrossEntropyLoss(weight=class_weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

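    # Forward inputs and the prediction target depend on the guesser setting.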
    forward_kwargs_mapping = {
        'dialogue': 'source_dialogue',
        'dialogue_lengths': 'dialogue_lengths'
    }

    if args.setting == 'baseline':
        forward_kwargs_mapping['object_categories'] = 'object_categories'
        forward_kwargs_mapping['object_bboxes'] = 'object_bboxes'
        forward_kwargs_mapping['num_objects'] = 'num_objects'
        target_kwarg = 'target_id'
    elif args.setting == 'category-only':
        target_kwarg = 'target_category'
    elif args.setting == 'mrcnn':
        forward_kwargs_mapping['object_categories'] = 'object_categories'
        forward_kwargs_mapping['object_bboxes'] = 'object_bboxes'
        forward_kwargs_mapping['num_objects'] = 'num_objects'
        forward_kwargs_mapping['visual_features'] = 'mrcnn_visual_features'
        target_kwarg = 'target_id'

    best_val_acc = 0
    for epoch in range(args.epochs):
        if not args.eval:
            train_loss, train_acc = eval_epoch(model, data_loader['train'],
                                               forward_kwargs_mapping,
                                               target_kwarg, loss_fn,
                                               optimizer)

        valid_loss, valid_acc = eval_epoch(model, data_loader['valid'],
                                           forward_kwargs_mapping,
                                           target_kwarg, loss_fn)

        if args.eval:
            test_loss, test_acc = eval_epoch(model, data_loader['test'],
                                             forward_kwargs_mapping,
                                             target_kwarg, loss_fn)

            print("Valid Loss {:07.4f} Valid Acc {:07.4f}".format(
                valid_loss, valid_acc * 100))
            print("Test Loss {:07.4f} Test Acc {:07.4f}".format(
                test_loss, test_acc * 100))

            break

        else:

            if valid_acc > best_val_acc:
                best_val_acc = valid_acc
                model.save(
                    os.path.join(
                        'bin',
                        'guesser_{}_{}_{}.pt'.format(args.exp_name,
                                                     args.setting, ts)))

            logger.add_scalar('train_loss', train_loss, epoch)
            logger.add_scalar('valid_loss', valid_loss, epoch)
            logger.add_scalar('train_acc', train_acc, epoch)
            logger.add_scalar('valid_acc', valid_acc, epoch)

            print(
                ("Epoch {:2d}/{:2d} Train Loss {:07.4f} Vaild Loss {:07.4f} " +
                 "Train Acc {:07.4f} Vaild Acc {:07.4f}").format(
                     epoch, args.epochs, train_loss, valid_loss,
                     train_acc * 100, valid_acc * 100))
Code Example #4
def main(args):
    print(args)

    if not args.eval:
        ts = datetime.datetime.now().timestamp()
        logger = SummaryWriter(
            os.path.join('exp/oracle/', '{}_{}'.format(args.exp_name, ts)))
        logger.add_text('exp_name', args.exp_name)
        logger.add_text('args', str(args))

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab = Vocab(os.path.join(args.data_dir, 'vocab.csv'), args.min_occ)
    category_vocab = CategoryVocab(
        os.path.join(args.data_dir, 'categories.csv'))

    # if args.use_film:
    #     filmed_resnet = filmed_resnet50(langugae_embedding_size=args.hidden_size)
    #     # filmwrapper = torch.nn.DataParallel(filmwrapper, device_ids=[0, 1, 2, 3])
    # else:
    #     filmed_resnet = None

    #global_film = MultiHopFiLM(128, args.hidden_size, 128)
    global_film = None
    crop_film = None

    data_loader = OrderedDict()
    if not args.eval:
        splits = ['train', 'valid']
    else:
        splits = ['valid', 'test']

    for split in splits:
        file = os.path.join(args.data_dir, 'guesswhat.' + split + '.jsonl.gz')
        data_loader[split] = DataLoader(
            dataset=OracleDataset(file,
                                  vocab,
                                  category_vocab,
                                  True,
                                  load_crops=args.use_film,
                                  crops_folder=args.crops_folder,
                                  global_features=args.global_features,
                                  global_mapping=args.global_mapping,
                                  crop_features=args.crop_features,
                                  crop_mapping=args.crop_mapping),
            batch_size=args.batch_size,
            shuffle=split == 'train',
            collate_fn=OracleDataset.get_collate_fn(device))

    if not args.eval:
        model = Oracle(len(vocab), args.word_embedding_dim,
                       len(category_vocab), args.category_embedding_dim,
                       args.hidden_size, args.mlp_hidden, global_film,
                       crop_film).to(device)
    else:
        model = Oracle.load(device, file=args.bin)

    loss_fn = torch.nn.CrossEntropyLoss()

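    # Only optimize parameters that still require gradients.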
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    forward_kwargs_mapping = {
        'question': 'question',
        'question_lengths': 'question_lengths',
        # 'question_mask': 'question_mask',
        # 'global_features': 'global_features',
        # 'crop_features': 'crop_features',
        'object_categories': 'target_category',
        'object_bboxes': 'target_bbox'
    }
    if args.use_film:
        forward_kwargs_mapping['crop'] = 'crop'
    target_kwarg = 'target_answer'

    best_val_acc = 0

    for epoch in range(args.epochs):
        if not args.eval:
            train_loss, train_acc = eval_epoch(
                model,
                data_loader['train'],
                forward_kwargs_mapping,
                target_kwarg,
                loss_fn,
                optimizer,
                clip_norm_args=[args.clip_value])

        valid_loss, valid_acc = eval_epoch(model, data_loader['valid'],
                                           forward_kwargs_mapping,
                                           target_kwarg, loss_fn)
        if args.eval:
            test_loss, test_acc = eval_epoch(model, data_loader['test'],
                                             forward_kwargs_mapping,
                                             target_kwarg, loss_fn)

            print("Valid Loss {:07.4f} Valid Acc {:07.4f}".format(
                valid_loss, valid_acc))
            print("Test Loss {:07.4f} Test Acc {:07.4f}".format(
                test_loss, test_acc))

            break

        else:
            if valid_acc > best_val_acc:
                best_val_acc = valid_acc
                model.save(
                    os.path.join('bin',
                                 'oracle_{}_{}.pt'.format(args.exp_name, ts)))

            logger.add_scalar('train_loss', train_loss, epoch)
            logger.add_scalar('valid_loss', valid_loss, epoch)
            logger.add_scalar('train_acc', train_acc, epoch)
            logger.add_scalar('valid_acc', valid_acc, epoch)

            print(
                ("Epoch {:2d}/{:2d} Train Loss {:07.4f} Vaild Loss {:07.4f} " +
                 "Train Acc {:07.4f} Vaild Acc {:07.4f}").format(
                     epoch, args.epochs, train_loss, valid_loss,
                     train_acc * 100, valid_acc * 100))
Code Example #5
import time

import numpy as np
import torch

# NOTE: model, optimizer, scheduler, train_loader, valid_loader, save_name,
# get_lr and the train_epoch/eval_epoch helpers are assumed to be defined
# elsewhere in the project.
loss_fun = torch.nn.MSELoss()

min_mse = float('inf')  # best validation MSE seen so far
train_mse = []
valid_mse = []
test_mse = []

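# Training loop with early stopping; the best model (by validation MSE) is
# checkpointed to disk.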
for i in range(n_epochs):
    start = time.time()

    model.train()
    # use train_epoch_scale/eval_epoch_scale for training scale equivariant models
    train_mse.append(train_epoch(train_loader, model, optimizer, loss_fun))
    # step the LR scheduler after the epoch's optimizer updates (PyTorch >= 1.1 order)
    scheduler.step()
    model.eval()
    mse, _, _ = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        best_model = model  # alias of model; the checkpoint below holds the best weights
        torch.save(best_model, save_name + ".pth")
    end = time.time()

    # Early Stopping but train at least for 50 epochs
    if (len(train_mse) > 50
            and np.mean(valid_mse[-5:]) >= np.mean(valid_mse[-10:-5])):
        break
    print(i + 1, train_mse[-1], valid_mse[-1], round((end - start) / 60, 5),
          format(get_lr(optimizer), "5.2e"))
Code Example #6
File: trainQGen.py  Project: timbmg/belief
def main(args):

    print(args)

    ts = datetime.datetime.now().timestamp()

    logger = SummaryWriter(
        os.path.join('exp/qgen/', '{}_{}'.format(args.exp_name, ts)))
    logger.add_text('exp_name', args.exp_name)
    logger.add_text('args', str(args))

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vocab = Vocab(os.path.join(args.data_dir, 'vocab.csv'), args.min_occ)
    category_vocab = CategoryVocab(
        os.path.join(args.data_dir, 'categories.csv'))

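    # Decide which set of pre-extracted visual features the dataset should load.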
    load_vgg_features, load_resnet_features = False, False
    if args.visual_representation == 'vgg':
        load_vgg_features = True
    elif args.visual_representation == 'resnet-mlb':
        load_resnet_features = True

    data_loader = OrderedDict()
    splits = ['train', 'valid']

    for split in splits:
        file = os.path.join(args.data_dir, 'guesswhat.' + split + '.jsonl.gz')
        data_loader[split] = DataLoader(
            dataset=QuestionerDataset(
                split,
                file,
                vocab,
                category_vocab,
                True,
                load_vgg_features=load_vgg_features,
                load_resnet_features=load_resnet_features),
            batch_size=args.batch_size,
            shuffle=split == 'train',
            #collate_fn=QuestionerDataset.get_collate_fn(device),
            collate_fn=QuestionerDataset.collate_fn)

    model = QGen(len(vocab),
                 args.word_embedding_dim,
                 args.num_visual_features,
                 args.visual_embedding_dim,
                 args.hidden_size,
                 visual_representation=args.visual_representation,
                 query_tokens=vocab.answer_tokens).to(device)
    print(model)

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    forward_kwargs_mapping = {
        'dialogue': 'source_dialogue',
        'dialogue_lengths': 'dialogue_lengths'
    }
    if load_vgg_features:
        forward_kwargs_mapping['visual_features'] = 'vgg_features'
    if load_resnet_features:
        forward_kwargs_mapping['visual_features'] = 'resnet_features'

    target_kwarg = 'target_dialogue'

    best_val_loss = 1e9
    for epoch in range(args.epochs):
        train_loss, _ = eval_epoch(model, data_loader['train'],
                                   forward_kwargs_mapping, target_kwarg,
                                   loss_fn, optimizer)

        valid_loss, _ = eval_epoch(model, data_loader['valid'],
                                   forward_kwargs_mapping, target_kwarg,
                                   loss_fn)

        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            model.save(
                os.path.join('bin', 'qgen_{}_{}.pt'.format(args.exp_name, ts)))

        logger.add_scalar('train_loss', train_loss, epoch)
        logger.add_scalar('valid_loss', valid_loss, epoch)

        print(("Epoch {:2d}/{:2d} Train Loss {:07.4f} Vaild Loss {:07.4f}"
               ).format(epoch, args.epochs, train_loss, valid_loss))
Code Example #7
def main(args):

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    train_loader, test_loader = load_dataset(args.label, args.batch_size,
                                             args.half_length, args.nholes)

    if args.label == 10:
        model = ShakeResNet(args.depth, args.w_base, args.label)
    else:
        model = ShakeResNeXt(args.depth, args.w_base, args.cardinary,
                             args.label)

    model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True

    if args.optimizer == 'sgd':
        print("using sgd")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay,
                        nesterov=args.nesterov)

    elif args.optimizer == 'abd':
        print("using adabound")
        opt = abd.AdaBound(model.parameters(),
                           lr=args.lr,
                           gamma=args.gamma,
                           weight_decay=args.weight_decay,
                           final_lr=args.final_lr)

    elif args.optimizer == 'swa':
        print("using swa")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
        steps_per_epoch = len(train_loader.dataset) / args.batch_size
        steps_per_epoch = int(steps_per_epoch)
        opt = swa(opt,
                  swa_start=args.swa_start * steps_per_epoch,
                  swa_freq=steps_per_epoch,
                  swa_lr=args.swa_lr)
    else:
        raise SystemExit("not valid optimizer: {}".format(args.optimizer))

    loss_func = nn.CrossEntropyLoss().cuda()

    headers = [
        "Epoch", "LearningRate", "TrainLoss", "TestLoss", "TrainAcc.",
        "TestAcc."
    ]

    #if args.optimizer=='swa':
    #   headers = headers[:-1] + ['swa_te_loss', 'swa_te_acc'] + headers[-1:]
    #  swa_res = {'loss': None, 'accuracy': None}

    logger = utils.Logger(args.checkpoint, headers, mod=args.optimizer)

    for e in range(args.epochs):

        if args.optimizer == 'swa':
            lr = utils.schedule(e, args.optimizer, args.epochs, args.swa_start,
                                args.swa_lr, args.lr)
            utils.adjust_learning_rate(opt, lr)
        elif args.optimizer == 'sgd':
            lr = utils.cosine_lr(opt, args.lr, e, args.epochs)
        else:
            # other optimizers (e.g. AdaBound) manage their own schedule; keep the base lr
            lr = args.lr

        #train
        train_loss, train_acc, train_n = utils.train_epoch(
            train_loader, model, opt)
        #eval
        test_loss, test_acc, test_n = utils.eval_epoch(test_loader, model)

        logger.write(e + 1, lr, train_loss / train_n, test_loss / test_n,
                     train_acc / train_n * 100, test_acc / test_n * 100)

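        # Swap in the SWA-averaged weights, refresh BatchNorm running statistics
        # on the training data, then swap the SGD weights back in.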
        if args.optimizer == 'swa' and (
                e + 1) >= args.swa_start and args.eval_freq > 1:
            if e == 0 or e % args.eval_freq == args.eval_freq - 1 or e == args.epochs - 1:
                opt.swap_swa_sgd()
                opt.bn_update(train_loader, model, device='cuda')
                #swa_res = utils.eval_epoch(test_loaders['test'], model)
                opt.swap_swa_sgd()