Example #1
0
File: main.py  Project: yxue3357/mammoth
def main():
    """Entry point: parse the command line, build the model/dataset pair and
    dispatch to the appropriate training loop."""
    lecun_fix()

    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        help='Model name.',
                        choices=get_all_models())
    parser.add_argument('--load_best_args',
                        action='store_true',
                        help='Loads the best arguments for each method, '
                        'dataset and memory buffer.')
    add_management_args(parser)
    # First pass: only --model / management flags; model-specific flags come later.
    args, _ = parser.parse_known_args()
    model_module = importlib.import_module('models.' + args.model)

    if args.load_best_args:
        # Use the tuned hyper-parameters from the best_args table instead of
        # asking for them on the command line.
        parser.add_argument('--dataset',
                            type=str,
                            required=True,
                            choices=DATASET_NAMES,
                            help='Which dataset to perform experiments on.')
        if hasattr(model_module, 'Buffer'):
            parser.add_argument('--buffer_size',
                                type=int,
                                required=True,
                                help='The size of the memory buffer.')
        args = parser.parse_args()

        # 'joint' shares the tuned hyper-parameters of plain SGD.
        lookup = 'sgd' if args.model == 'joint' else args.model
        best = best_args[args.dataset][lookup]
        # Joint training on the GCL mnist-360 protocol uses a dedicated model.
        if args.model == 'joint' and args.dataset == 'mnist-360':
            args.model = 'joint_gcl'
        # Buffer-based methods index by buffer size; others use the -1 entry.
        best = best[args.buffer_size] if hasattr(args, 'buffer_size') else best[-1]
        for key, value in best.items():
            setattr(args, key, value)
    else:
        # Second pass: let the chosen model define its own full argument set.
        args = getattr(model_module, 'get_parser')().parse_args()

    if args.seed is not None:
        set_random_seed(args.seed)

    # MER performs single-example updates by design.
    if args.model == 'mer':
        args.batch_size = 1

    dataset = get_dataset(args)
    model = get_model(args, dataset.get_backbone(), dataset.get_loss(),
                      dataset.get_transform())

    if isinstance(dataset, ContinualDataset):
        train(model, dataset, args)
    else:
        # General-continual path: only the dedicated joint model may own end_task.
        assert not hasattr(model, 'end_task') or model.NAME == 'joint_gcl'
        ctrain(args)
Example #2
0
        print('done')

        print('arguments for train:')
        print(train_args)
        
        print('rebuilding model...')
        # Rebuild the network with the checkpoint's vocabulary size, then
        # restore the trained weights and optimiser state.
        model = Set2Seq(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        # NOTE(review): the optimiser class comes from the checkpointed
        # train_args, but the learning rate comes from the CURRENT args —
        # confirm that mixing the two is intentional.
        model_optimizer = train_args.optimiser(model.parameters(), lr=args.learning_rate)
        model_optimizer.load_state_dict(checkpoint['opt'])
        print('done')

    print('loading test data...')
    # reverse=True presumably swaps source/target order — TODO confirm against
    # PairDataset's definition.
    test_set = PairDataset(voc, dataset_file_path=args.test_file, reverse=True)
    print('done')
    
    # Evaluate and report sequence-level and token-level accuracy as percentages.
    test_seq_acc, test_tok_acc, test_loss = eval_model(model, test_set)
    print("[TEST]Loss: {:.4f}; Seq-level Accuracy: {:.4f}; Tok-level Accuracy: {:.4f}".format(
                test_loss, test_seq_acc * 100, test_tok_acc * 100)
         )

if __name__ == '__main__':
    # Seed first so both the test and train paths are reproducible.
    set_random_seed(args.seed)
    # detect_anomaly() enables autograd sanity checks — a debugging aid that
    # slows execution down considerably.
    with autograd.detect_anomaly():
        print('with detect_anomaly')
        entry = test if args.test else train
        entry()
Example #3
0
def main():
    """Entry point: configure CPU threading, parse the command line and launch
    either joint or continual training."""
    # Grant more CPU threads when multiple GPUs feed the data pipeline.
    gpu_count = torch.cuda.device_count()
    torch.set_num_threads(6 * gpu_count if gpu_count > 1 else 2)

    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        help='Model name.',
                        choices=get_all_models())
    parser.add_argument('--load_best_args',
                        action='store_true',
                        help='Loads the best arguments for each method, '
                        'dataset and memory buffer.')
    add_management_args(parser)
    # First pass: only --model / management flags; model-specific flags come later.
    args, _ = parser.parse_known_args()
    model_module = importlib.import_module('models.' + args.model)

    if args.load_best_args:
        # Use the tuned hyper-parameters from the best_args table instead of
        # asking for them on the command line.
        parser.add_argument('--dataset',
                            type=str,
                            required=True,
                            choices=DATASET_NAMES,
                            help='Which dataset to perform experiments on.')
        if hasattr(model_module, 'Buffer'):
            parser.add_argument('--buffer_size',
                                type=int,
                                required=True,
                                help='The size of the memory buffer.')
        args = parser.parse_args()
        # 'joint' shares the tuned hyper-parameters of plain SGD.
        lookup = 'sgd' if args.model == 'joint' else args.model
        best = best_args[args.dataset][lookup]
        # Buffer-based methods index by buffer size; others use the -1 entry.
        best = best[args.buffer_size] if hasattr(args, 'buffer_size') else best[-1]
        for key, value in best.items():
            setattr(args, key, value)
    else:
        # Second pass: let the chosen model define its own full argument set.
        args = getattr(model_module, 'get_parser')().parse_args()

    if args.seed is not None:
        set_random_seed(args.seed)

    # Joint training on seq-core50 is rerouted through a dedicated dataset
    # variant trained with plain SGD.
    off_joint = args.model == 'joint' and args.dataset == 'seq-core50'
    if off_joint:
        args.dataset = 'seq-core50j'
        args.model = 'sgd'

    dataset = get_dataset(args)

    # continual learning
    model = get_model(args, dataset.get_backbone(), dataset.get_loss(),
                      dataset.get_transform())
    if off_joint:
        print('BEGIN JOINT TRAINING')
        jtrain(model, dataset, args)
    else:
        print('BEGIN CONTINUAL TRAINING')
        train(model, dataset, args)
Example #4
0
        # NOTE(review): torch.load unpickles arbitrary objects — only restore
        # checkpoints from trusted sources.
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']  # namespace saved at training time
        voc = checkpoint['voc']  # vocabulary object stored in the checkpoint
        print('done')

        print('arguments for train:')
        print(train_args)

        print('rebuilding model...')
        # Rebuild the architecture from the checkpoint's vocabulary size and
        # restore the trained weights (no optimiser state needed for testing).
        model = Set2Seq2Choice(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        print('done')

    print('loading test data...')
    test_set = ChooseDataset(voc, dataset_file_path=args.test_file)
    print('done')

    # Evaluate and report accuracy as a percentage.
    test_acc, test_loss = eval_model(model, test_set)
    print("[TEST]Loss: {:.4f}; Accuracy: {:.4f};".format(
        test_loss, test_acc * 100))


if __name__ == '__main__':
    # Fixed seed so runs are reproducible.
    set_random_seed(1234)
    # detect_anomaly() enables autograd sanity checks — a debugging aid that
    # slows execution down considerably.
    with autograd.detect_anomaly():
        print('with detect_anomaly')
        (test if args.test else train)()
Example #5
0
File: main.py  Project: SunWenJu123/ILCOC
def main():
    """Entry point: assemble a hard-coded experiment configuration on an
    argparse namespace and launch training."""
    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    args, _ = parser.parse_known_args()

    args.model = 'ocilfast'
    args.seed = None
    args.validation = True

    # Directory where result images are stored; the result log is opened here
    # and carried on the namespace for the training code to write into.
    args.img_dir = 'img/test'
    args.print_file = open('../'+args.img_dir+'/result.txt', mode='w')

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    """
    # seq-tinyimagenet
    args.dataset = 'seq-tinyimg'
    args.lr = 2e-3
    args.batch_size = 32  
    args.buffer_size = 0
    args.minibatch_size = 32  
    args.n_epochs = 100 

    args.nu = 0.7  
    args.eta = 0.04  
    args.eps = 1 
    args.embedding_dim = 250  
    args.weight_decay = 1e-2  
    args.margin = 1  
    args.r = 0.01  
    args.nf = 32
    """

    # Active configuration: seq-cifar10.
    cifar10_config = {
        'dataset': 'seq-cifar10',
        'lr': 1e-3,
        'batch_size': 32,
        'buffer_size': 0,
        'minibatch_size': 32,
        'n_epochs': 50,
        'nu': 0.7,
        'eta': 0.8,
        'eps': 1,
        'embedding_dim': 250,
        'weight_decay': 1e-2,
        'margin': 1,
        'r': 0.1,
        'nf': 32,
    }
    for option, value in cifar10_config.items():
        setattr(args, option, value)

    # seq-mnist (alternative configuration, kept for reference)
    # args.dataset = 'seq-mnist'
    # args.buffer_size = 0
    # args.lr = 1e-3
    # args.batch_size = 128
    # args.minibatch_size = 128
    # args.n_epochs = 10
    #
    # args.nu = 0.8
    # args.eta = 1
    # args.eps = 0.1
    # args.embedding_dim = 150
    # args.weight_decay = 0
    # args.margin = 5
    # args.r = 0.1                # radius
    # args.nf = 32

    # args.seed is None above, so this is currently a no-op.
    if args.seed is not None:
        set_random_seed(args.seed)

    train(args)