Example #1
    # print(dictionaries[0])

    print('Initializing CLEVR dataset...')

    # Build the model
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules,
                      n_choices, args.n_nodes, n_words, args.embed_size,
                      args.lstm_hid_dim, args.input_feat_dim, args.map_dim,
                      args.mlp_hid_dim, args.mem_dim, args.kb_dim,
                      args.kernel_size, device).to(device)

    # load checkpoint
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model = nn.DataParallel(model)

    clevr_test = ClevrDataset(args.clevr_dir,
                              split='test',
                              features_dir=args.features_dir,
                              dictionaries=dictionaries)
    test_set = DataLoader(clevr_test,
                          batch_size=args.batch_size,
                          num_workers=1)
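
The fragment stops before any evaluation runs. A minimal sketch of consuming the test loader follows; the batch layout (image, question, question length, answer) and the model's forward signature are assumptions, since neither appears in the snippet:

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for image, question, q_len, answer in test_set:  # batch layout assumed
            image, question = image.to(device), question.to(device)
            answer = answer.to(device)
            logits = model(question, q_len, image)  # forward signature assumed
            correct += (logits.argmax(dim=1) == answer).sum().item()
            total += answer.size(0)
    print('test accuracy = {:.4f}'.format(correct / total))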
Example #2
def main():
    parser = argparse.ArgumentParser(description='Stack-NMN')
    parser.add_argument('--embed_size', type=int, help='embedding dim. of question words', default=300)
    parser.add_argument('--lstm_hid_dim', type=int, help='hidden dim. of LSTM', default=256)
    parser.add_argument('--input_feat_dim', type=int, help='feat dim. of image features', default=1024)
    parser.add_argument('--map_dim', type=int, help='hidden dim. size of intermediate attention maps', default=512)
    parser.add_argument('--text_param_dim', type=int, help='hidden dim. of textual param.', default=512)
    parser.add_argument('--mlp_hid_dim', type=int, help='hidden dim. of mlp', default=512)
    parser.add_argument('--mem_dim', type=int, help='hidden dim. of mem.', default=512)
    parser.add_argument('--kb_dim', type=int, help='hidden dim. of conv features.', default=512)
    parser.add_argument('--max_stack_len', type=int, help='max. length of stack', default=8)
    parser.add_argument('--max_time_stamps', type=int, help='max. number of time-stamps for modules', default=9)
    parser.add_argument('--clevr_dir', type=str, help='Directory of CLEVR dataset', required=True)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--n_modules', type=int, default=9)
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--features_dir', type=str, default='data')
    parser.add_argument('--clevr_feature_dir', type=str, default='/u/username/data/clevr_features/')
    parser.add_argument('--copy_data', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--reg_coeff', type=float, default=1e-1)
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resume', action='store_true') # use only on slurm
    parser.add_argument('--optim', type=str, default='adam')
    parser.add_argument('--use_half', action='store_true') # use only on slurm

    # SGDR hyper-params
    parser.add_argument('--T0', type=int, default=1)
    parser.add_argument('--Tmult', type=int, default=2)
    parser.add_argument('--eta_min', type=float, default=1e-5)

    args = parser.parse_args()
    print(args)
    '''
    with open('data/dic.pkl', 'rb') as f1:
        dic = pickle.load(f1)

    n_words = len(dic['word_dic']) + 1
    n_choices = len(dic['answer_dic'])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))
    '''

    print('Building word dictionaries from all the words in the dataset...')

    dictionaries = utils_clevr_humans.build_dictionaries(args.clevr_dir)
    print('Building word dictionary completed!')

    print('Initializing CLEVR dataset...')

    # Build the model
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    writer = SummaryWriter(log_dir=args.model_dir)

    if args.copy_data:
        start_time = time.time()
        copytree(args.clevr_feature_dir, os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/'))
        # copytree('/u/username/data/clevr_features/', '/Tmp/username/clevr_features/')
        # args.features_dir = '/Tmp/username/clevr_features/'
        args.features_dir = os.path.join(os.path.expandvars('$SLURM_TMPDIR'), 'clevr_features/')
        print('data copy finished in {} sec.'.format(time.time() - start_time))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print('device = {}'.format(device))

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules,
                      n_choices, 3, n_words, args.embed_size, args.lstm_hid_dim,
                      args.input_feat_dim, args.map_dim, args.mlp_hid_dim,
                      args.mem_dim, args.kb_dim, args.kernel_size, False,
                      device).to(device)
    # model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    # optimizer_1 = optim.Adam(map(lambda p:p[1], filter(lambda p:p[1].requires_grad and 'weight_mlp' not in p[0], model.named_parameters())), lr=args.lr, weight_decay=0e-3)

    # optimizer_2 = optim.Adam(map(lambda p:p[1], filter(lambda p:p[1].requires_grad and 'weight_mlp' in p[0], model.named_parameters())), lr=0e-8, weight_decay=0e-2)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'adam':
        optimizer = optim.Adam(trainable_params, lr=args.lr)
    elif args.optim == 'asgd':
        optimizer = optim.ASGD(trainable_params, lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(trainable_params, lr=args.lr)
    elif args.optim == 'adamax':
        optimizer = optim.Adamax(trainable_params, lr=args.lr)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(trainable_params, lr=args.lr)
    elif args.optim == 'adadelta':
        optimizer = optim.Adadelta(trainable_params, lr=args.lr)
    elif args.optim == 'sgdr':
        # SGDR = plain SGD here; the warm restarts come from a scheduler (see below)
        optimizer = optim.SGD(trainable_params, lr=args.lr)
    else:
        raise ValueError('unknown optimizer: {}'.format(args.optim))
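
    # Sketch (not in the original script): the SGDR hyper-parameters parsed
    # above (--T0, --Tmult, --eta_min) are otherwise unused in this snippet.
    # With PyTorch >= 1.1 they could drive warm restarts for the 'sgdr' branch:
    #   scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #       optimizer, T_0=args.T0, T_mult=args.Tmult, eta_min=args.eta_min)
    # followed by scheduler.step() once per epoch inside the training loop.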

    clevr_val = ClevrDataset(args.clevr_dir, split='val', features_dir=args.features_dir, dictionaries=dictionaries)

    val_set = DataLoader(clevr_val, batch_size=args.batch_size, num_workers=4, collate_fn=collate_data)

    start_epoch = 0
    
    if len(args.ckpt) > 0:
        # Load every tensor except the embedding verbatim; the vocabulary (and
        # hence the embedding table) may have grown since the checkpoint was
        # written, so the old rows are copied into the larger table below.
        ckpt_state = torch.load(args.ckpt, map_location=device)
        model.load_state_dict({k: v for k, v in ckpt_state.items() if 'embed' not in k}, strict=False)
        # start_epoch = int(args.ckpt.split('_')[-1].split('.')[0])
        # print('start_epoch = {}'.format(start_epoch))
        prev_embed = ckpt_state['embed.weight']
        model.embed.weight.data[:prev_embed.size(0), :].copy_(prev_embed)

    # print(model.embed.weight.data)
    
    if args.resume:
        model_ckpts = list(filter(lambda x: 'ckpt_epoch' in x, os.listdir(args.model_dir)))

        if len(model_ckpts) > 0:
            model_ckpts_epoch_ids = [int(filename.split('_')[-1].split('.')[0]) for filename in model_ckpts]
            start_epoch = max(model_ckpts_epoch_ids)
            latest_ckpt_file = os.path.join(args.model_dir, 'ckpt_epoch_{}.model'.format(start_epoch))
            model.load_state_dict(torch.load(latest_ckpt_file, map_location=device))
            print('Loaded ckpt file {}'.format(latest_ckpt_file))
            print('start_epoch = {}'.format(start_epoch))

    val(model, criterion, optimizer, val_set, args.batch_size, device, writer, args.n_epochs)

    writer.close()
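
A note on the checkpoint block above: it transplants a pretrained embedding into a model whose vocabulary (and hence embedding table) has grown, e.g. when reusing CLEVR weights for CLEVR-Humans (whose dictionary is built by utils_clevr_humans.build_dictionaries). Distilled into a standalone helper, the trick reads as the sketch below; the helper name is hypothetical, but the key filter and row copy mirror the code above:

def load_with_grown_vocab(model, ckpt_path, device):
    # Load every tensor except the embedding verbatim (strict=False tolerates
    # the missing embedding key), then copy the old rows into the top of the
    # new, larger embedding table; rows for new words keep their fresh init.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict({k: v for k, v in state.items() if 'embed' not in k}, strict=False)
    prev_embed = state['embed.weight']
    with torch.no_grad():
        model.embed.weight[:prev_embed.size(0), :].copy_(prev_embed)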
Example #3
def main():
    parser = argparse.ArgumentParser(description='Stack-NMN')
    parser.add_argument('--embed_size',
                        type=int,
                        help='embedding dim. of question words',
                        default=300)
    parser.add_argument('--lstm_hid_dim',
                        type=int,
                        help='hidden dim. of LSTM',
                        default=256)
    parser.add_argument('--input_feat_dim',
                        type=int,
                        help='feat dim. of image features',
                        default=1024)
    parser.add_argument('--map_dim',
                        type=int,
                        help='hidden dim. size of intermediate attention maps',
                        default=512)
    # parser.add_argument('--text_param_dim', type=int, help='hidden dim. of textual param.', default=512)
    parser.add_argument('--mlp_hid_dim',
                        type=int,
                        help='hidden dim. of mlp',
                        default=512)
    parser.add_argument('--mem_dim',
                        type=int,
                        help='hidden dim. of mem.',
                        default=512)
    parser.add_argument('--kb_dim',
                        type=int,
                        help='hidden dim. of conv features.',
                        default=512)
    parser.add_argument('--max_stack_len',
                        type=int,
                        help='max. length of stack',
                        default=8)
    parser.add_argument('--max_time_stamps',
                        type=int,
                        help='max. number of time-stamps for modules',
                        default=9)
    parser.add_argument('--clevr_dir',
                        type=str,
                        help='Directory of CLEVR dataset',
                        required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--n_modules', type=int,
                        default=7)  # includes 1 NoOp module
    parser.add_argument('--n_nodes', type=int,
                        default=2)  # TODO: change/tune later
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--features_dir', type=str, default='data')
    parser.add_argument('--clevr_feature_dir',
                        type=str,
                        default='/u/username/data/clevr_features/')
    parser.add_argument('--copy_data', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--reg_coeff', type=float, default=1e-2)
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resume', action='store_true')  # use only on slurm
    parser.add_argument('--reg_coeff_op_loss', type=float, default=1e-1)

    # DARTS args
    parser.add_argument('--unrolled',
                        action='store_true',
                        default=False,
                        help='use one-step unrolled validation loss')
    parser.add_argument('--arch_learning_rate',
                        type=float,
                        default=3e-4,
                        help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay',
                        type=float,
                        default=1e-3,
                        help='weight decay for arch encoding')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    args = parser.parse_args()

    print('Building word dictionaries from all the words in the dataset...')

    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Building word dictionary completed!')

    print('Initializing CLEVR dataset...')

    # Build the model
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    writer = SummaryWriter(log_dir=args.model_dir)

    if args.copy_data:
        start_time = time.time()
        copytree(
            args.clevr_feature_dir,
            os.path.join(os.path.expandvars('$SLURM_TMPDIR'),
                         'clevr_features/'))
        # copytree('/u/username/data/clevr_features/', '/Tmp/username/clevr_features/')
        # args.features_dir = '/Tmp/username/clevr_features/'
        args.features_dir = os.path.join(os.path.expandvars('$SLURM_TMPDIR'),
                                         'clevr_features/')
        print('data copy finished in {} sec.'.format(time.time() - start_time))

    # TODO: remove this later
    # args.features_dir = '/Tmp/username/clevr_features/'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print('device = {}'.format(device))

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules,
                      n_choices, args.n_nodes, args.temperature, n_words,
                      args.embed_size, args.lstm_hid_dim, args.input_feat_dim,
                      args.map_dim, args.mlp_hid_dim, args.mem_dim,
                      args.kb_dim, args.kernel_size, device).to(device)
    # model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules, n_choices, n_words, args.embed_size, args.lstm_hid_dim, args.input_feat_dim, args.map_dim, args.text_param_dim, args.mlp_hid_dim, args.kernel_size, device).to(device)
    start_epoch = 0

    if len(args.ckpt) > 0:
        model.load_state_dict(torch.load(args.ckpt, map_location=device))
        start_epoch = int(args.ckpt.split('_')[-1].split('.')[0])
        print('start_epoch = {}'.format(start_epoch))

    if args.resume:
        model_ckpts = list(
            filter(lambda x: 'ckpt_epoch' in x, os.listdir(args.model_dir)))

        if len(model_ckpts) > 0:
            model_ckpts_epoch_ids = [
                int(filename.split('_')[-1].split('.')[0])
                for filename in model_ckpts
            ]
            start_epoch = max(model_ckpts_epoch_ids)
            latest_ckpt_file = os.path.join(
                args.model_dir, 'ckpt_epoch_{}.model'.format(start_epoch))
            model.load_state_dict(torch.load(latest_ckpt_file, map_location=device))
            print('Loaded ckpt file {}'.format(latest_ckpt_file))
            print('start_epoch = {}'.format(start_epoch))

    model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    # optimizer_1 = optim.Adam(map(lambda p:p[1], filter(lambda p:p[1].requires_grad and 'weight_mlp' not in p[0], model.named_parameters())), lr=args.lr, weight_decay=0e-3)

    # optimizer_2 = optim.Adam(map(lambda p:p[1], filter(lambda p:p[1].requires_grad and 'weight_mlp' in p[0], model.named_parameters())), lr=0e-8, weight_decay=0e-2)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.module.network_parameters()),
                           lr=args.lr)
    # optimizer = optim.Adam(filter(lambda p:p.requires_grad, model.network_parameters()), lr=args.lr)
    # optimizer = optim.Adam(map(lambda x:x[1], filter(lambda p:p[1].requires_grad, model.named_parameters())), lr=args.lr)

    architect = Architect(model, device, args)

    clevr_train = ClevrDataset(args.clevr_dir,
                               split='train',
                               features_dir=args.features_dir,
                               dictionaries=dictionaries)
    clevr_val = ClevrDataset(args.clevr_dir,
                             split='val',
                             features_dir=args.features_dir,
                             dictionaries=dictionaries)

    train_set = DataLoader(clevr_train,
                           batch_size=args.batch_size,
                           num_workers=0,
                           collate_fn=collate_data)
    val_set = DataLoader(clevr_val,
                         batch_size=args.batch_size,
                         num_workers=0,
                         collate_fn=collate_data)
    val_set_architect = DataLoader(clevr_val,
                                   batch_size=args.batch_size,
                                   num_workers=0,
                                   collate_fn=collate_data,
                                   sampler=torch.utils.data.RandomSampler(
                                       list(range(len(clevr_val)))))
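    # Note: RandomSampler(list(range(len(clevr_val)))) only uses the list's
    # length, so it draws uniformly random indices; passing clevr_val itself
    # (or shuffle=True) would be the equivalent idiomatic form.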

    for epoch_id in range(start_epoch, args.n_epochs):
        train(epoch_id, model, architect, criterion, optimizer, train_set,
              val_set_architect, args.batch_size, device, writer,
              args.n_epochs, args.lr, args.unrolled, args.reg_coeff)
        valid(epoch_id, model, criterion, optimizer, val_set, args.batch_size,
              device, writer, args.n_epochs)

        with open(
                '{}/ckpt_epoch_{}.model'.format(args.model_dir,
                                                str(epoch_id + 1)),
                'wb') as f1:
            torch.save(model.module.state_dict(), f1)

    clevr_train.close()
    clevr_val.close()
    writer.close()
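
The Architect class that performs the bi-level DARTS update is not shown in these examples. For orientation, the first-order variant (the path taken when --unrolled is off in the reference DARTS implementation) reduces to an ordinary optimizer step on the architecture parameters against a validation batch. The sketch below assumes the model exposes arch_parameters() and a (question, q_len, image, answer) batch layout; neither is confirmed by these snippets:

def first_order_arch_step(model, criterion, arch_optimizer, val_batch, device):
    # One first-order DARTS step: minimize the validation loss w.r.t. the
    # architecture parameters only. The network weights are not updated because
    # arch_optimizer was built solely over the architecture parameters, e.g.
    #   arch_optimizer = optim.Adam(model.module.arch_parameters(),
    #                               lr=args.arch_learning_rate,
    #                               weight_decay=args.arch_weight_decay)
    question, q_len, image, answer = val_batch  # batch layout assumed
    question, image, answer = question.to(device), image.to(device), answer.to(device)
    arch_optimizer.zero_grad()
    loss = criterion(model(question, q_len, image), answer)  # forward signature assumed
    loss.backward()
    arch_optimizer.step()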