Ejemplo n.º 1
0
    #cnnp = CNN('./cnn_config_1.5TP_optimal.json', 0)

    #this part is for 5-runs accuracy
    #--------------------------------
    if False:
        oris = []
        gens = []
        ps1s = []
        rs1s = []
        fs1s = []
        ps2s = []
        rs2s = []
        fs2s = []
        for seed in range(5):
            cnn = CNN('./cnn_config_1.5T_optimal.json', 0)
            cnn.train(verbose=1)
            print(cnn.epoch)
            cnn.model.load_state_dict(torch.load('{}CNN_{}.pth'.format(cnn.checkpoint_dir, cnn.epoch)))

            cnnp = CNN('./cnn_config_1.5TP_optimal.json', 0)
            cnnp.train(verbose=1)
            print(cnnp.epoch)
            cnnp.model.load_state_dict(torch.load('{}CNN_{}.pth'.format(cnnp.checkpoint_dir, cnnp.epoch)))
            print('iter', seed, 'testing accuracy:', cnn.test(), cnnp.test())
            ori, gen = eval_cnns(cnn, cnnp)
            ps_1, rs_1, fs_1, ps_2, rs_2, fs_2 = PRF_cnns(cnn, cnnp)
            ps1s += [ps_1]
            rs1s += [rs_1]
            fs1s += [fs_1]
            ps2s += [ps_2]
            rs2s += [rs_2]
Ejemplo n.º 2
0
def pre_train(dataloader,
              test_loader,
              dict_loader,
              dataloader_test,
              mask_labels,
              total_epochs=50,
              learning_rate=1e-4,
              use_gpu=True,
              seed=123):
    """Train the hashing CNN against a sampled anchor set with EMA output targets.

    Each epoch:
      1. Rebuilds an Adam optimizer with ramp-scheduled lr/betas.
      2. Draws anchors from ``dict_loader``.
      3. Trains on every full-size batch of ``dataloader``.
      4. Maintains bias-corrected exponential moving averages of the per-sample
         and per-anchor network outputs. These are passed to ``criterion`` as
         ``zcomp_h1`` / ``zcomp_h2``.
      5. Evaluates mAP with ``helpers.validate`` and checkpoints via
         ``save_checkpoint``.

    Options are read from the module-level ``parser``: num_bits, num_class,
    dataset, checkpoint, anchor_num and batch_size.

    Args:
        dataloader: training loader. Each batch is a dict with ``'image'``,
            ``'labels'`` and ``'index'`` entries. Batches whose size differs
            from ``args.batch_size`` are skipped.
        test_loader: passed to ``helpers.validate``.
        dict_loader: passed to ``generate_anchor_vectors`` to build the anchors.
        dataloader_test: passed to ``helpers.validate``.
        mask_labels: NumPy array indexed by sample index. Entries > 0 select
            the samples whose labels are passed to ``CalcSim``.
        total_epochs: number of training epochs.
        learning_rate: NOTE(review): currently ignored. It is overwritten each
            epoch by ``rampup * rampdown * 5e-5``.
        use_gpu: move model/criterion/inputs to CUDA and wrap the model in
            ``DataParallel``. NOTE(review): several buffers are created with
            ``.cuda()`` unconditionally, so ``use_gpu=False`` still needs CUDA.
        seed: passed to ``torch.cuda.manual_seed`` when ``use_gpu`` is set.
            The NumPy RNG used for anchor shuffling is not seeded here.

    Returns:
        The trained model (``DataParallel``-wrapped when ``use_gpu`` is true).
    """

    args = parser.parse_args()
    pprint(args)

    num_bits = args.num_bits

    model = CNN(model_name='alexnet', bit=num_bits, class_num=args.num_class)

    criterion = custom_loss(num_bits=num_bits)

    arch = 'cnn_'
    filename = arch + args.dataset + '_' + str(num_bits) + "bits"
    checkpoint_filename = os.path.join(args.checkpoint, filename + '.pt')

    if use_gpu:
        model = model.cuda()
        model = torch.nn.DataParallel(model,
                                      device_ids=range(
                                          torch.cuda.device_count()))
        criterion = criterion.cuda()
        torch.cuda.manual_seed(seed)

    running_loss = 0.0  # cumulative over all epochs; not reported anywhere

    start_epoch = 0
    # NOTE(review): these meters are never updated in this function, so the
    # timings printed each epoch are just their initial values.
    batch_time = AverageMeter()
    data_time = AverageMeter()

    best_prec = -99999

    # Number of rows in the per-sample buffers below. This assumes every
    # dataset index is < n_samples -- TODO confirm against the dataset size.
    n_samples = 200000

    alpha = 0.4  # EMA decay for per-sample outputs (applied once per epoch)
    alpha_1 = 0.99  # EMA decay for anchor outputs (applied after each trained batch)

    mask_labels = torch.from_numpy(mask_labels).long().cuda()

    # Per-sample buffers:
    #   Z = raw EMA accumulator
    #   z = bias-corrected EMA (handed to the criterion)
    #   h = latest network outputs
    Z_h1 = torch.zeros(n_samples,
                       num_bits).float().cuda()  # intermediate values
    z_h1 = torch.zeros(n_samples, num_bits).float().cuda()  # temporal outputs
    h1 = torch.zeros(n_samples, num_bits).float().cuda()  # current outputs

    # The same three buffers for the anchor set.
    Z_h2 = torch.zeros(args.anchor_num,
                       num_bits).float().cuda()  # intermediate values
    z_h2 = torch.zeros(args.anchor_num,
                       num_bits).float().cuda()  # temporal outputs
    h2 = torch.zeros(args.anchor_num,
                     num_bits).float().cuda()  # current outputs

    for epoch in range(start_epoch, total_epochs):
        model.train(True)
        rampup_value = rampup(epoch)
        rampdown_value = rampdown(epoch)
        learning_rate = rampup_value * rampdown_value * 0.00005
        adam_beta1 = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
        adam_beta2 = step_rampup(epoch) * 0.99 + (1 -
                                                  step_rampup(epoch)) * 0.999

        # Weight handed to the criterion as u_w_m: 0 in the first epoch,
        # then 5 * rampup(epoch).
        if epoch == 0:
            u_w = 0.0
        else:
            u_w = rampup_value

        u_w_m = u_w * 5

        u_w_m = torch.autograd.Variable(torch.FloatTensor([u_w_m]).cuda(),
                                        requires_grad=False)

        # A new optimizer is built each epoch so that the scheduled lr/betas
        # take effect.
        # NOTE(review): this also discards Adam's moment estimates every epoch.
        optimizer = Adam(model.parameters(),
                         lr=learning_rate,
                         betas=(adam_beta1, adam_beta2),
                         eps=1e-8,
                         amsgrad=True)

        anchors_data, anchor_Label = generate_anchor_vectors(dict_loader)

        # Reset these per epoch. This stops the summary below from reporting
        # a previous epoch's loss, or reading unbound names, when no
        # full-size batch was processed.
        loss = l_batch_loss = m_loss = None

        for iteration, data in enumerate(dataloader, 0):

            # Pick 100 distinct anchor indices at random for this batch.
            anchor_index = np.arange(args.anchor_num)
            np.random.shuffle(anchor_index)

            anchor_index = anchor_index[:100]

            anchor_index = torch.from_numpy(anchor_index).long().cuda()

            anchor_inputs = anchors_data[anchor_index, :, :, :]
            anchor_labels = anchor_Label[anchor_index, :]

            inputs, labels, index = data['image'], data['labels'], data[
                'index']

            labels = labels.float()

            mask_flag = Variable(mask_labels[index], requires_grad=False)
            idx = (mask_flag > 0)

            # Only full-size batches are trained on; a trailing partial batch
            # is skipped.
            if index.shape[0] == args.batch_size:
                anchor_batch_S, anchor_batch_W = CalcSim(
                    labels[idx, :].cuda(), anchor_labels.cuda())

                # Move a trailing size-3 axis to position 1
                # (presumably channels-last -> NCHW).
                if inputs.size(3) == 3:
                    inputs = inputs.permute(0, 3, 1, 2)
                inputs = inputs.type(torch.FloatTensor)

                # Bias-corrected EMA targets for this batch and its anchors.
                zcomp_h1 = z_h1[index.cuda(), :]
                zcomp_h2 = z_h2[anchor_index, :]

                labeled_batch_S, labeled_batch_W = CalcSim(
                    labels[idx, :].cuda(), labels[idx, :].cuda())

                if use_gpu:
                    inputs = Variable(inputs.cuda(), requires_grad=False)
                    anchor_batch_S = Variable(anchor_batch_S.cuda(),
                                              requires_grad=False)
                    anchor_batch_W = Variable(anchor_batch_W.cuda(),
                                              requires_grad=False)
                    labeled_batch_S = Variable(labeled_batch_S.cuda(),
                                               requires_grad=False)
                    labeled_batch_W = Variable(labeled_batch_W.cuda(),
                                               requires_grad=False)

                # zero the parameter gradients
                optimizer.zero_grad()

                y_h1 = model(inputs)
                y_h2 = model(anchor_inputs)

                # Scaled inner products between every batch output and every
                # anchor output, squashed into (0, 1).
                y = torch.sigmoid(48 / num_bits * 0.4 *
                                  torch.matmul(y_h1, y_h2.permute(1, 0)))

                loss, l_batch_loss, m_loss = criterion(
                    y, y_h1, y_h2, anchor_batch_S, anchor_batch_W,
                    labeled_batch_S, labeled_batch_W, zcomp_h1, zcomp_h2,
                    mask_flag, u_w_m, epoch, num_bits)

                # Record the latest outputs for the EMA updates.
                h1[index, :] = y_h1.data.clone()
                h2[anchor_index, :] = y_h2.data.clone()

                # backward+optimize
                loss.backward()

                optimizer.step()

                running_loss += loss.item()

                # The anchor EMA is advanced after every trained batch.
                Z_h2 = alpha_1 * Z_h2 + (1. - alpha_1) * h2
                z_h2 = Z_h2 * (1. / (1. - alpha_1**(epoch + 1)))

        if loss is None:
            print("Epoch[{}]: no batch of size {} was processed".format(
                epoch, args.batch_size))
        else:
            print(
                "Epoch[{}]({}/{}): Time:(data {:.3f}/ batch {:.3f}) Loss_H: {:.4f}/{:.4f}/{:.4f}"
                .format(epoch, iteration, len(dataloader), data_time.val,
                        batch_time.val, loss.item(), l_batch_loss.item(),
                        m_loss.item()))

        # The per-sample EMA is advanced once per epoch and bias-corrected by
        # 1 / (1 - alpha**(epoch + 1)).
        Z_h1 = alpha * Z_h1 + (1. - alpha) * h1
        z_h1 = Z_h1 * (1. / (1. - alpha**(epoch + 1)))

        # Validate and checkpoint after every epoch.
        MAP = helpers.validate(model, dataloader_test, test_loader)

        print("Test image map is:{}".format(MAP))

        is_best = MAP > best_prec
        best_prec = max(best_prec, MAP)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            prefix=arch,
            num_bits=num_bits,
            filename=checkpoint_filename)

    return model
Ejemplo n.º 3
0
class Mind:
    """Q-learning agent with decaying random exploration and a CNN Q-network.

    How the agent works:
      - ``action`` picks a random action with probability
        ``1000 / (1000 + lifetime)``. Otherwise it picks the argmax of the
        network's Q-values.
      - ``feedback`` buffers transitions.
      - Every ``update_interval`` calls to ``action``, the buffer is used to
        refit the network.
      - Every ``save_interval`` refits, the network is saved to ``save_path``
        (when one is set).
    """

    def __init__(self,
                 n_actions,
                 input_shape,
                 save_path=None,
                 load_path=None,
                 action_inverses=None,
                 update_interval=256 * 2,
                 save_interval=5):
        """Create the agent and its Q-network.

        Args:
            n_actions: number of discrete actions (network output size).
            input_shape: shape of a single state, forwarded to ``CNN``.
            save_path: destination for periodic saves; ``None`` disables them.
            load_path: if not ``None``, network weights are loaded from here.
            action_inverses: optional mapping from an action to its inverse.
                In this class it is only referenced by commented-out code in
                ``action``. Defaults to a new empty dict per instance. (The
                former ``{}`` default was one dict shared by every ``Mind``.)
            update_interval: refit the network every this many ``action`` calls.
            save_interval: save the network every this many refits.
        """
        self.n_actions = n_actions
        self.save_path = save_path
        self.network = CNN(n_out=n_actions, input_shape=input_shape)

        if load_path is not None:
            self.network.load(load_path)

        self.n_features = input_shape
        self.data = []  # buffered transitions, consumed by update()
        self.current_episode_count = 1  # 1-based step counter for the episode

        self.random_actions = 0  # exploratory actions since the last reset()

        self.last_action = None
        self.last_action_random = False

        self.action_inverses = {} if action_inverses is None else action_inverses

        self.lifetime = 1  # 1 + total action() calls; drives exploration decay
        self.update_interval = update_interval
        self.save_interval = save_interval
        self.n_updates = 1

    def q(self, state):
        """Return the network's Q-value vector for a single ``state``."""
        return self.network.predict(np.expand_dims(np.array(state), axis=0))[0]

    def should_explore(self, state):
        """Return True with probability ``1000 / (1000 + lifetime)``.

        ``state`` is currently unused.
        """
        return np.random.random() < 1000 / (1000 + self.lifetime)

    def explore_action(self, state):
        """Return a uniformly random action in ``[0, n_actions)``."""
        return np.random.randint(0, self.n_actions)

    def action(self, state):
        """Choose an action for ``state`` and advance the agent's counters.

        This may also trigger a network refit and a periodic save.
        """
        q = self.q(state)
        #        if self.last_action_random:
        #            if self.last_action in self.action_inverses:
        #                q[self.action_inverses[self.last_action]] = float('-inf')

        action = np.argmax(q)

        if self.should_explore(state):
            self.random_actions += 1
            action = self.explore_action(state)
            self.last_action_random = True
        else:
            self.last_action_random = False

        if self.lifetime % self.update_interval == 0:
            self.update(alpha=0.9)
            self.n_updates += 1
            if self.n_updates % self.save_interval == 0:
                if self.save_path is not None:
                    self.save(self.save_path)
                    print('saved')

        self.last_action = action
        self.current_episode_count += 1
        self.lifetime += 1

        return action

    def save(self, path):
        """Save the network to ``path``."""
        self.network.save(path)

    def reset(self):
        """Start a new episode: reset the step counter, report and clear random actions."""
        # action() increments current_episode_count. The original only set an
        # otherwise-unused `count` attribute, so the episode counter never
        # reset. `count` is still set for backward compatibility.
        self.current_episode_count = 1
        self.count = 1
        print('Random actions: ', self.random_actions)
        self.random_actions = 0

    def q_target(self, reward, best_next, alpha):
        """Return the Q-learning target ``reward + alpha * best_next``.

        ``alpha`` weights the bootstrapped next-state value, so it plays the
        role of the discount factor.
        """
        return reward + alpha * best_next

    def feedback(self, old_action, old_state, reward, new_state):
        """Buffer one transition for the next ``update``.

        The max Q-value of ``new_state`` is evaluated now, with the current
        network, rather than at update time.
        """
        self.data.append({
            'Q_max': np.max(self.q(new_state)),
            'reward': reward,
            'old_state': old_state,
            'old_action': old_action
        })

    def update(self, alpha=0.6):
        """Refit the network on all buffered transitions, then clear the buffer.

        For each transition, the target vector is the current prediction for
        ``old_state``, with the taken action's entry replaced by
        ``q_target``. An empty buffer is a no-op.
        """
        samples = self.data
        self.data = []
        if not samples:
            # Nothing to learn from; avoid calling train() on empty arrays.
            return
        np.random.shuffle(samples)

        states = []
        ys = []
        for sample in samples:
            y = self.q(sample['old_state'])
            y[sample['old_action']] = self.q_target(sample['reward'],
                                                    best_next=sample['Q_max'],
                                                    alpha=alpha)
            states.append(sample['old_state'])
            ys.append(y)
        self.network.train(np.array(states), np.array(ys))