Exemplo n.º 1
0
def main(args):
    """Extract max- and avg-pooled features from a trained RN on CLEVR 'val'.

    Loads the RN weights from ``args.checkpoint``, streams the CLEVR validation
    images through the model via ``extract_features_rl`` and writes its output
    to ``./features/max_features.pickle`` and ``./features/avg_features.pickle``.

    Args:
        args: parsed command-line namespace. Read: ``checkpoint``, ``no_cuda``,
            ``clevr_dir``, ``batch_size``. Written: ``cuda``, ``features_dirs``,
            ``qdict_size``, ``adict_size``.
    """
    assert os.path.isfile(args.checkpoint), "Checkpoint file not found: {}".format(args.checkpoint)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    test_transforms = transforms.Compose([transforms.Resize((128, 128)),
                                          transforms.ToTensor()])

    # Initialize CLEVR Loader
    clevr_dataset_images = ClevrDatasetImages(args.clevr_dir, 'val', test_transforms)
    # NOTE(review): drop_last=True means images in a final partial batch get no
    # features; confirm extract_features_rl really requires full batches.
    clevr_feat_extraction_loader = DataLoader(clevr_dataset_images, batch_size=args.batch_size,
                                              shuffle=False, num_workers=8, drop_last=True)

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)

    max_features_path = os.path.join(args.features_dirs, 'max_features.pickle')
    avg_features_path = os.path.join(args.features_dirs, 'avg_features.pickle')

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Load the model checkpoint. Tensors are mapped to CPU first so a
    # GPU-saved checkpoint also loads on a CPU-only host; load_state_dict then
    # copies them onto the model's own device.
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)

    # A checkpoint saved from a DataParallel model prefixes every key with
    # 'module.' (pytorch bug #3805). Align the keys with how `model` is
    # wrapped here rather than always stripping the prefix.
    prefix = 'module.'
    is_parallel = isinstance(model, torch.nn.DataParallel)
    has_prefix = any(k.startswith(prefix) for k in checkpoint.keys())
    if has_prefix and not is_parallel:
        checkpoint = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in checkpoint.items()}
    elif is_parallel and not has_prefix:
        checkpoint = {prefix + k: v for k, v in checkpoint.items()}

    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    # Context managers guarantee both pickles are flushed and closed even if
    # extraction raises part-way through.
    with open(max_features_path, 'wb') as max_features:
        with open(avg_features_path, 'wb') as avg_features:
            extract_features_rl(clevr_feat_extraction_loader, max_features,
                                avg_features, model, args)
Exemplo n.º 2
0
def main(args):
    """Train (or test) a Relation Network on CLEVR.

    Builds the word dictionaries and CLEVR train/test loaders (images or state
    descriptions, depending on ``args.state_description``), constructs the RN
    model, optionally resumes weights from ``args.resume`` and/or transfers a
    pretrained conv layer from ``args.conv_transfer_learn``.

    With ``args.test`` a single evaluation is run. Otherwise the model is
    trained for epochs ``start_epoch..args.epochs`` and ``RN_epoch_XX.pth`` is
    saved into ``args.model_dirs`` after every epoch.

    Args:
        args: parsed command-line namespace. This function also sets
            ``model_dirs``, ``features_dirs``, ``test_results_dir``, ``cuda``,
            ``qdict_size`` and ``adict_size`` on it.

    Raises:
        ValueError: if ``args.resume`` names an existing file whose name does
            not contain ``epoch_<N>.pth``.
    """
    args.model_dirs = './model_{}_b{}_lr{}'.format(args.model, args.batch_size,
                                                   args.lr)
    args.features_dirs = './features'
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)

    args.test_results_dir = './test_results'
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')

    if not args.state_description:
        train_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.Pad(8),
            transforms.RandomCrop((128, 128)),
            transforms.RandomRotation(2.8),  # .05 rad
            transforms.ToTensor()
        ])
        test_transforms = transforms.Compose(
            [transforms.Resize((128, 128)),
             transforms.ToTensor()])

        clevr_dataset_train = ClevrDataset(args.clevr_dir, True, dictionaries,
                                           train_transforms)
        clevr_dataset_test = ClevrDataset(args.clevr_dir, False, dictionaries,
                                          test_transforms)

        # Use a weighted sampler for training:
        weights = clevr_dataset_train.answer_weights()
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights))

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(clevr_dataset_train,
                                        batch_size=args.batch_size,
                                        sampler=sampler,
                                        num_workers=8,
                                        collate_fn=utils.collate_samples_image)
        clevr_test_loader = DataLoader(clevr_dataset_test,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=8,
                                       collate_fn=utils.collate_samples_image)
    else:
        clevr_dataset_train = ClevrDatasetStateDescription(
            args.clevr_dir, True, dictionaries)
        clevr_dataset_test = ClevrDatasetStateDescription(
            args.clevr_dir, False, dictionaries)

        # Initialize Clevr dataset loaders
        clevr_train_loader = DataLoader(
            clevr_dataset_train,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=8,
            collate_fn=utils.collate_samples_state_description)
        clevr_test_loader = DataLoader(
            clevr_dataset_test,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=8,
            collate_fn=utils.collate_samples_state_description)

    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])
    model = RN(args)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Decide from the actual wrapping, not from the GPU count: with --no_cuda
    # or on a CPU-only host the model is never wrapped in DataParallel.
    is_parallel = isinstance(model, torch.nn.DataParallel)
    base_model = model.module if is_parallel else model
    prefix = 'module.'

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename,
                                    map_location=lambda storage, loc: storage)

            # A checkpoint saved from a DataParallel model prefixes every key
            # with 'module.' (pytorch bug #3805); align keys with `model`.
            has_prefix = any(k.startswith(prefix) for k in checkpoint.keys())
            if has_prefix and not is_parallel:
                checkpoint = {k[len(prefix):] if k.startswith(prefix) else k: v
                              for k, v in checkpoint.items()}
            elif is_parallel and not has_prefix:
                checkpoint = {prefix + k: v for k, v in checkpoint.items()}

            model.load_state_dict(checkpoint)
            print('==> loaded checkpoint {}'.format(filename))

            # Training continues from the epoch after the one encoded in the
            # checkpoint's file name (e.g. RN_epoch_07.pth -> 8).
            epoch_match = re.match(r'.*epoch_(\d+)\.pth', filename)
            if epoch_match is None:
                raise ValueError(
                    'Cannot infer the epoch from checkpoint name {}'.format(
                        filename))
            start_epoch = int(epoch_match.group(1)) + 1
        else:
            print('Cannot load file {}'.format(filename))

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            print('==> loading conv layer from {}'.format(
                args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained conv layer
            pretrained_dict = torch.load(
                args.conv_transfer_learn,
                map_location=lambda storage, loc: storage)

            conv_dict = base_model.conv.state_dict()

            # Keep only the conv-layer entries, re-keyed relative to the conv
            # module. A leading 'module.' (checkpoint saved from DataParallel,
            # pytorch issue #3805) is dropped first so those keys match too.
            conv_pretrained_dict = {}
            for key, value in pretrained_dict.items():
                if key.startswith(prefix):
                    key = key[len(prefix):]
                if key.startswith('conv.'):
                    conv_pretrained_dict[key[len('conv.'):]] = value

            # overwrite entries in the existing state dict
            conv_dict.update(conv_pretrained_dict)

            # load the new state dict
            base_model.conv.load_state_dict(conv_dict)

            # To freeze the transferred weights, set requires_grad = False on
            # every parameter in base_model.conv.parameters() here.

            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
    else:
        # perform a full training
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=1e-4)
        print('Training ({} epochs) is starting...'.format(args.epochs))
        # The progress bar is only needed (and only drawn) when training.
        progress_bar = trange(start_epoch, args.epochs + 1)
        for epoch in progress_bar:
            # TRAIN
            progress_bar.set_description('TRAIN')
            train(clevr_train_loader, model, optimizer, epoch, args)
            # TEST
            progress_bar.set_description('TEST')
            test(clevr_test_loader, model, epoch, dictionaries, args)
            # SAVE MODEL
            filename = 'RN_epoch_{:02d}.pth'.format(epoch)
            torch.save(model.state_dict(),
                       os.path.join(args.model_dirs, filename))
Exemplo n.º 3
0
def main():
    """Train a Stack-NMN on pre-extracted CLEVR features (DARTS-style search).

    Parses the command line, builds the word dictionaries and constructs a
    ``Stack_NMN`` model. Weights are optionally restored from ``--ckpt`` or,
    with ``--resume``, from the newest ``ckpt_epoch_<N>.model`` in
    ``--model_dir``. The model is then wrapped in ``nn.DataParallel``.

    Each epoch runs ``train`` (with an ``Architect`` driven by a shuffled
    validation loader) followed by ``valid``, then saves
    ``ckpt_epoch_<epoch+1>.model`` into ``--model_dir``. TensorBoard logs are
    written to ``--model_dir`` as well.
    """
    parser = argparse.ArgumentParser(description='Stack-NMN')
    parser.add_argument('--embed_size',
                        type=int,
                        help='embedding dim. of question words',
                        default=300)
    parser.add_argument('--lstm_hid_dim',
                        type=int,
                        help='hidden dim. of LSTM',
                        default=256)
    parser.add_argument('--input_feat_dim',
                        type=int,
                        help='feat dim. of image features',
                        default=1024)
    parser.add_argument('--map_dim',
                        type=int,
                        help='hidden dim. size of intermediate attention maps',
                        default=512)
    parser.add_argument('--mlp_hid_dim',
                        type=int,
                        help='hidden dim. of mlp',
                        default=512)
    parser.add_argument('--mem_dim',
                        type=int,
                        help='hidden dim. of mem.',
                        default=512)
    parser.add_argument('--kb_dim',
                        type=int,
                        help='hidden dim. of conv features.',
                        default=512)
    parser.add_argument('--max_stack_len',
                        type=int,
                        help='max. length of stack',
                        default=8)
    parser.add_argument('--max_time_stamps',
                        type=int,
                        help='max. number of time-stamps for modules',
                        default=9)
    parser.add_argument('--clevr_dir',
                        type=str,
                        help='Directory of CLEVR dataset',
                        required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--n_epochs', type=int, default=50)
    parser.add_argument('--n_modules', type=int,
                        default=7)  # includes 1 NoOp module
    parser.add_argument('--n_nodes', type=int,
                        default=2)  # TODO: change/tune later
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--features_dir', type=str, default='data')
    parser.add_argument('--clevr_feature_dir',
                        type=str,
                        default='/u/username/data/clevr_features/')
    parser.add_argument('--copy_data', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--reg_coeff', type=float, default=1e-2)
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resume', action='store_true')  # use only on slurm
    parser.add_argument('--reg_coeff_op_loss', type=float, default=1e-1)

    # DARTS args
    parser.add_argument('--unrolled',
                        action='store_true',
                        default=False,
                        help='use one-step unrolled validation loss')
    parser.add_argument('--arch_learning_rate',
                        type=float,
                        default=3e-4,
                        help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay',
                        type=float,
                        default=1e-3,
                        help='weight decay for arch encoding')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    args = parser.parse_args()
    if args.copy_data and 'SLURM_TMPDIR' not in os.environ:
        # Without it the '$SLURM_TMPDIR' placeholder would stay literal and the
        # features would be copied into a directory with that name.
        parser.error('--copy_data requires the SLURM_TMPDIR environment variable')

    print('Building word dictionaries from all the words in the dataset...')

    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Building word dictionary completed!')

    print('Initializing CLEVR dataset...')

    # Build the model
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    writer = SummaryWriter(log_dir=args.model_dir)

    if args.copy_data:
        # Stage the features on the node-local SLURM scratch directory and
        # read them from there for the rest of the run.
        start_time = time.time()
        local_features_dir = os.path.join(os.environ['SLURM_TMPDIR'],
                                          'clevr_features/')
        copytree(args.clevr_feature_dir, local_features_dir)
        args.features_dir = local_features_dir
        print('data copy finished in {} sec.'.format(time.time() - start_time))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules,
                      n_choices, args.n_nodes, args.temperature, n_words,
                      args.embed_size, args.lstm_hid_dim, args.input_feat_dim,
                      args.map_dim, args.mlp_hid_dim, args.mem_dim,
                      args.kb_dim, args.kernel_size, device).to(device)
    start_epoch = 0

    # Checkpoints are mapped onto the selected device so GPU-saved weights
    # also load on a CPU-only host.
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device))
        # File names end in '_<epoch>.<ext>'.
        start_epoch = int(args.ckpt.split('_')[-1].split('.')[0])
        print('start_epoch = {}'.format(start_epoch))

    if args.resume:
        model_ckpts = [name for name in os.listdir(args.model_dir)
                       if 'ckpt_epoch' in name]

        if model_ckpts:
            start_epoch = max(
                int(name.split('_')[-1].split('.')[0]) for name in model_ckpts)
            latest_ckpt_file = os.path.join(
                args.model_dir, 'ckpt_epoch_{}.model'.format(start_epoch))
            model.load_state_dict(
                torch.load(latest_ckpt_file, map_location=device))
            print('Loaded ckpt file {}'.format(latest_ckpt_file))
            print('start_epoch = {}'.format(start_epoch))

    model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    # Only the network weights go to this optimizer; the Architect is given
    # the model and args separately.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.module.network_parameters()),
                           lr=args.lr)

    architect = Architect(model, device, args)

    clevr_train = ClevrDataset(args.clevr_dir,
                               split='train',
                               features_dir=args.features_dir,
                               dictionaries=dictionaries)
    clevr_val = ClevrDataset(args.clevr_dir,
                             split='val',
                             features_dir=args.features_dir,
                             dictionaries=dictionaries)

    train_set = DataLoader(clevr_train,
                           batch_size=args.batch_size,
                           num_workers=0,
                           collate_fn=collate_data)
    val_set = DataLoader(clevr_val,
                         batch_size=args.batch_size,
                         num_workers=0,
                         collate_fn=collate_data)
    # RandomSampler only needs len() of its source, so the dataset itself
    # serves as the index space.
    val_set_architect = DataLoader(clevr_val,
                                   batch_size=args.batch_size,
                                   num_workers=0,
                                   collate_fn=collate_data,
                                   sampler=torch.utils.data.RandomSampler(
                                       clevr_val))

    try:
        for epoch_id in range(start_epoch, args.n_epochs):
            train(epoch_id, model, architect, criterion, optimizer, train_set,
                  val_set_architect, args.batch_size, device, writer,
                  args.n_epochs, args.lr, args.unrolled, args.reg_coeff)
            valid(epoch_id, model, criterion, optimizer, val_set,
                  args.batch_size, device, writer, args.n_epochs)

            ckpt_path = os.path.join(
                args.model_dir, 'ckpt_epoch_{}.model'.format(epoch_id + 1))
            with open(ckpt_path, 'wb') as f1:
                # Save the unwrapped module so keys carry no 'module.' prefix.
                torch.save(model.module.state_dict(), f1)
    finally:
        # Release the datasets and flush TensorBoard logs even if training fails.
        clevr_train.close()
        clevr_val.close()
        writer.close()
Exemplo n.º 4
0
def main(args):
    """Extract pooled features from a trained (2S-)RN checkpoint on CLEVR.

    Hyperparameters are read from ``args.config`` (JSON, key
    ``hyperparams[args.model]``), with the question-injection position
    optionally overridden by ``args.question_injection``. The RN weights are
    loaded from ``args.checkpoint`` and ``extract_features_rl`` writes the
    features into pickle files under ``./features``:

    * ``args.extr_layer_idx >= 0``: ``<set>_2S-RN_{max,avg}_features.pickle``
    * otherwise: ``<set>_RN_{avg,max}_features.pickle``

    Args:
        args: parsed command-line namespace. This function also sets ``cuda``,
            ``features_dirs``, ``qdict_size`` and ``adict_size`` on it.
    """
    # load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override the configured question-injection position
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection

    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    assert os.path.isfile(
        args.checkpoint), "Checkpoint file not found: {}".format(
            args.checkpoint)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Initialize CLEVR Loader
    clevr_dataset_test = initialize_dataset(
        args.clevr_dir, args.set == 'train', hyp['state_description'])
    clevr_feat_extraction_loader = reload_loaders(clevr_dataset_test,
                                                  args.batch_size,
                                                  hyp['state_description'])

    args.features_dirs = './features'
    if not os.path.exists(args.features_dirs):
        os.makedirs(args.features_dirs)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Word dictionary completed!')
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])

    print('Cuda: {}'.format(args.cuda))
    model = RN(args, hyp, extraction=True)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Load the model checkpoint (mapped to CPU first so it loads anywhere).
    print('==> loading checkpoint {}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)

    # A checkpoint saved from a DataParallel model prefixes every key with
    # 'module.' (pytorch bug #3805). Align keys with how `model` is actually
    # wrapped; the GPU count alone is wrong for --no_cuda or CPU-only hosts.
    prefix = 'module.'
    is_parallel = isinstance(model, torch.nn.DataParallel)
    has_prefix = any(k.startswith(prefix) for k in checkpoint.keys())
    if has_prefix and not is_parallel:
        print('Removing \'module.\' prefix')
        checkpoint = {k[len(prefix):] if k.startswith(prefix) else k: v
                      for k, v in checkpoint.items()}
    elif is_parallel and not has_prefix:
        print('Adding \'module.\' prefix')
        checkpoint = {prefix + k: v for k, v in checkpoint.items()}

    model.load_state_dict(checkpoint)
    print('==> loaded checkpoint {}'.format(args.checkpoint))

    # The output files are opened only once the model is ready, so a failed
    # load does not truncate earlier results. They are always closed (and
    # flushed) afterwards.
    files_dict = {}
    try:
        if args.extr_layer_idx >= 0:  # g_layers features
            # NOTE(review): the layer index is not part of these file names, so
            # extracting different layers overwrites the same files.
            files_dict['max_features'] = open(
                os.path.join(args.features_dirs,
                             '{}_2S-RN_max_features.pickle'.format(args.set)),
                'wb')
            files_dict['avg_features'] = open(
                os.path.join(args.features_dirs,
                             '{}_2S-RN_avg_features.pickle'.format(args.set)),
                'wb')
        else:
            files_dict['avgconv_features'] = open(
                os.path.join(args.features_dirs,
                             '{}_RN_avg_features.pickle'.format(args.set)),
                'wb')
            files_dict['maxconv_features'] = open(
                os.path.join(args.features_dirs,
                             '{}_RN_max_features.pickle'.format(args.set)),
                'wb')

        extract_features_rl(clevr_feat_extraction_loader,
                            hyp['question_injection_position'],
                            args.extr_layer_idx, hyp['lstm_hidden'], files_dict,
                            model, args)
    finally:
        for feature_file in files_dict.values():
            feature_file.close()
Exemplo n.º 5
0
def main(args):
    """Train (or test) a configurable Relation Network on CLEVR.

    Hyperparameters come from ``args.config`` (JSON, key
    ``hyperparams[args.model]``), optionally overridden by ``args.dropout`` and
    ``args.question_injection``. A run directory encoding the main settings is
    created under ``args.exp_dir`` and the full configuration is dumped into
    ``config.txt`` there.

    Optional startup steps:

    * resume weights from ``args.resume`` (epoch parsed from ``epoch_<N>.pth``);
    * transfer a pretrained conv layer from ``args.conv_transfer_learn``.

    During training:

    * the batch size is ``floor(batch_size * bs_gamma ** (epoch // bs_step))``,
      capped at ``bs_max`` when positive;
    * a StepLR schedule is stepped only while the current lr is below
      ``lr_max`` (a negative ``lr_max`` disables the cap);
    * ``RN_epoch_XX.pth`` is saved into ``args.model_dirs`` after each epoch.

    Args:
        args: parsed command-line namespace. This function also sets
            ``model_dirs``, ``features_dirs``, ``test_results_dir``, ``cuda``,
            ``qdict_size`` and ``adict_size`` on it.
    """
    # load hyperparameters from configuration file
    with open(args.config) as config_file:
        hyp = json.load(config_file)['hyperparams'][args.model]
    # override configuration dropout / question-injection position
    if args.dropout > 0:
        hyp['dropout'] = args.dropout
    if args.question_injection >= 0:
        hyp['question_injection_position'] = args.question_injection

    print('Loaded hyperparameters from configuration {}, model: {}: {}'.format(
        args.config, args.model, hyp))

    args.model_dirs = '{}/model_{}_drop{}_bstart{}_bstep{}_bgamma{}_bmax{}_lrstart{}_'+ \
                      'lrstep{}_lrgamma{}_lrmax{}_invquests-{}_clipnorm{}_glayers{}_qinj{}_fc1{}_fc2{}'
    args.model_dirs = args.model_dirs.format(
        args.exp_dir, args.model, hyp['dropout'], args.batch_size,
        args.bs_step, args.bs_gamma, args.bs_max, args.lr, args.lr_step,
        args.lr_gamma, args.lr_max, args.invert_questions, args.clip_norm,
        hyp['g_layers'], hyp['question_injection_position'], hyp['f_fc1'],
        hyp['f_fc2'])
    if not os.path.exists(args.model_dirs):
        os.makedirs(args.model_dirs)
    # create a file in this folder containing the overall configuration
    args_str = str(args)
    hyp_str = str(hyp)
    all_configuration = args_str + '\n\n' + hyp_str
    filename = os.path.join(args.model_dirs, 'config.txt')
    with open(filename, 'w') as config_file:
        config_file.write(all_configuration)

    args.features_dirs = '{}/features'.format(args.exp_dir)
    args.test_results_dir = '{}/test_results'.format(args.exp_dir)
    if not os.path.exists(args.test_results_dir):
        os.makedirs(args.test_results_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('Building word dictionaries from all the words in the dataset...')
    dictionaries = utils.build_dictionaries(args.clevr_dir, args.exp_dir)
    print('Word dictionary completed!')

    print('Initializing CLEVR dataset...')
    clevr_dataset_train, clevr_dataset_test = initialize_dataset(
        args.clevr_dir, args.exp_dir, dictionaries, hyp['state_description'])
    print('CLEVR dataset initialized!')

    # Build the model
    args.qdict_size = len(dictionaries[0])
    args.adict_size = len(dictionaries[1])

    model = RN(args, hyp)

    if torch.cuda.device_count() > 1 and args.cuda:
        model = torch.nn.DataParallel(model)
        model.module.cuda()  # call cuda() overridden method

    if args.cuda:
        model.cuda()

    # Decide from the actual wrapping, not from the GPU count: with --no_cuda
    # or on a CPU-only host the model is never wrapped in DataParallel.
    is_parallel = isinstance(model, torch.nn.DataParallel)
    base_model = model.module if is_parallel else model
    prefix = 'module.'

    start_epoch = 1
    if args.resume:
        filename = args.resume
        if os.path.isfile(filename):
            print('==> loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filename,
                                    map_location=lambda storage, loc: storage)

            # A checkpoint saved from a DataParallel model prefixes every key
            # with 'module.' (pytorch bug #3805); align keys with `model`.
            has_prefix = any(k.startswith(prefix) for k in checkpoint.keys())
            if has_prefix and not is_parallel:
                checkpoint = {k[len(prefix):] if k.startswith(prefix) else k: v
                              for k, v in checkpoint.items()}
            elif is_parallel and not has_prefix:
                checkpoint = {prefix + k: v for k, v in checkpoint.items()}

            model.load_state_dict(checkpoint)
            print('==> loaded checkpoint {}'.format(filename))

            # Training continues from the epoch after the one encoded in the
            # checkpoint's file name (e.g. RN_epoch_07.pth -> 8).
            epoch_match = re.match(r'.*epoch_(\d+)\.pth', filename)
            if epoch_match is None:
                raise ValueError(
                    'Cannot infer the epoch from checkpoint name {}'.format(
                        filename))
            start_epoch = int(epoch_match.group(1)) + 1
        else:
            print('Cannot load file {}'.format(filename))

    if args.conv_transfer_learn:
        if os.path.isfile(args.conv_transfer_learn):
            print('==> loading conv layer from {}'.format(
                args.conv_transfer_learn))
            # pretrained dict is the dictionary containing the already trained conv layer
            pretrained_dict = torch.load(
                args.conv_transfer_learn,
                map_location=lambda storage, loc: storage)

            conv_dict = base_model.conv.state_dict()

            # Keep only the conv-layer entries, re-keyed relative to the conv
            # module. A leading 'module.' (checkpoint saved from DataParallel,
            # pytorch issue #3805) is dropped first so those keys match too.
            conv_pretrained_dict = {}
            for key, value in pretrained_dict.items():
                if key.startswith(prefix):
                    key = key[len(prefix):]
                if key.startswith('conv.'):
                    conv_pretrained_dict[key[len('conv.'):]] = value

            # overwrite entries in the existing state dict
            conv_dict.update(conv_pretrained_dict)

            # load the new state dict
            base_model.conv.load_state_dict(conv_dict)

            # To freeze the transferred weights, set requires_grad = False on
            # every parameter in base_model.conv.parameters() here.

            print("==> conv layer loaded!")
        else:
            print('Cannot load file {}'.format(args.conv_transfer_learn))

    if args.test:
        # perform a single test
        print('Testing epoch {}'.format(start_epoch))
        _, clevr_test_loader = reload_loaders(clevr_dataset_train,
                                              clevr_dataset_test,
                                              args.batch_size,
                                              args.test_batch_size,
                                              hyp['state_description'])
        test(clevr_test_loader, model, start_epoch, dictionaries, args)
    else:
        bs = args.batch_size

        # perform a full training
        # TODO: find a better solution for general lr scheduling policies
        # Learning rate for the resumed epoch. The epoch offset is
        # parenthesised so that it, and not just the 1, is divided by lr_step.
        candidate_lr = args.lr * args.lr_gamma**(
            (start_epoch - 1) // args.lr_step)
        # A negative lr_max disables the cap (same convention as the scheduler
        # check in the epoch loop).
        if args.lr_max < 0:
            lr = candidate_lr
        else:
            lr = min(candidate_lr, args.lr_max)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=lr,
                               weight_decay=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        args.lr_step,
                                        gamma=args.lr_gamma)
        scheduler.last_epoch = start_epoch
        print('Training ({} epochs) is starting...'.format(args.epochs))
        progress_bar = trange(start_epoch, args.epochs + 1)
        for epoch in progress_bar:
            # The batch size may still change while it is below a positive cap,
            # or always when no cap (negative bs_max) is set.
            bs_can_change = ((args.bs_max > 0 and bs < args.bs_max)
                             or args.bs_max < 0)
            # The loaders must exist before the first epoch, so they are always
            # built on start_epoch even when bs is already at its cap.
            if epoch == start_epoch or (bs_can_change
                                        and epoch % args.bs_step == 0):
                bs = math.floor(args.batch_size *
                                (args.bs_gamma**(epoch // args.bs_step)))
                if bs > args.bs_max and args.bs_max > 0:
                    bs = args.bs_max
                clevr_train_loader, clevr_test_loader = reload_loaders(
                    clevr_dataset_train, clevr_dataset_test, bs,
                    args.test_batch_size, hyp['state_description'])

                print('Dataset reinitialized with batch size {}'.format(bs))

            # Compare against the lr actually in use: calling
            # scheduler.get_lr() outside step() is deprecated and may return a
            # value that already includes the next gamma factor.
            current_lr = optimizer.param_groups[0]['lr']
            if ((args.lr_max > 0 and current_lr < args.lr_max)
                    or args.lr_max < 0):
                scheduler.step()

            print('Current learning rate: {}'.format(
                optimizer.param_groups[0]['lr']))

            # TRAIN
            progress_bar.set_description('TRAIN')
            train(clevr_train_loader, model, optimizer, epoch, args)

            # TEST
            progress_bar.set_description('TEST')
            test(clevr_test_loader, model, epoch, dictionaries, args)

            # SAVE MODEL
            filename = 'RN_epoch_{:02d}.pth'.format(epoch)
            torch.save(model.state_dict(),
                       os.path.join(args.model_dirs, filename))
Exemplo n.º 6
0
                        help='use one-step unrolled validation loss')
    parser.add_argument('--arch_learning_rate',
                        type=float,
                        default=3e-4,
                        help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay',
                        type=float,
                        default=1e-3,
                        help='weight decay for arch encoding')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    args = parser.parse_args()

    print('Building word dictionaries from all the words in the dataset...')

    dictionaries = utils.build_dictionaries(args.clevr_dir)
    print('Building word dictionary completed!')

    print('Initializing CLEVR dataset...')

    # Build the model
    n_words = len(dictionaries[0]) + 1
    n_choices = len(dictionaries[1])

    print('n_words = {}, n_choices = {}'.format(n_words, n_choices))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Stack_NMN(args.max_stack_len, args.max_time_stamps, args.n_modules,
                      n_choices, args.n_nodes, n_words, args.embed_size,
                      args.lstm_hid_dim, args.input_feat_dim, args.map_dim,