# cnnp = CNN('./cnn_config_1.5TP_optimal.json', 0)

# this part is for 5-runs accuracy
# --------------------------------
# NOTE(review): deliberately disabled experiment block; flip the guard to True
# to re-run the 5-seed comparison of the base CNN vs. the "P" variant.
if False:
    oris = []
    gens = []
    ps1s = []
    rs1s = []
    fs1s = []
    ps2s = []
    rs2s = []
    fs2s = []
    for seed in range(5):
        # Train the baseline model, then reload its best checkpoint.
        cnn = CNN('./cnn_config_1.5T_optimal.json', 0)
        cnn.train(verbose=1)
        print(cnn.epoch)
        cnn.model.load_state_dict(torch.load('{}CNN_{}.pth'.format(cnn.checkpoint_dir, cnn.epoch)))
        # Same for the "P" configuration.
        cnnp = CNN('./cnn_config_1.5TP_optimal.json', 0)
        cnnp.train(verbose=1)
        print(cnnp.epoch)
        cnnp.model.load_state_dict(torch.load('{}CNN_{}.pth'.format(cnnp.checkpoint_dir, cnnp.epoch)))
        print('iter', seed, 'testing accuracy:', cnn.test(), cnnp.test())
        ori, gen = eval_cnns(cnn, cnnp)
        ps_1, rs_1, fs_1, ps_2, rs_2, fs_2 = PRF_cnns(cnn, cnnp)
        # FIX: collect every per-seed result. The original never appended to
        # oris/gens and dropped fs_2, so those accumulators stayed empty.
        oris += [ori]
        gens += [gen]
        ps1s += [ps_1]
        rs1s += [rs_1]
        fs1s += [fs_1]
        ps2s += [ps_2]
        rs2s += [rs_2]
        fs2s += [fs_2]
def pre_train(dataloader, test_loader, dict_loader, dataloader_test, mask_labels,
              total_epochs=50, learning_rate=1e-4, use_gpu=True, seed=123):
    """Train a hashing CNN with anchor-based similarity and temporal ensembling.

    Args:
        dataloader: training batches; each item is a dict with keys
            'image', 'labels', 'index' (see usage below).
        test_loader: query set loader, forwarded to helpers.validate.
        dict_loader: loader used to build the anchor tensors.
        dataloader_test: database/gallery loader for MAP validation.
        mask_labels: numpy int array; >0 marks a sample as labeled.
        total_epochs: number of training epochs.
        learning_rate: ignored — overwritten each epoch by the ramp schedule
            below (NOTE(review): parameter is dead; confirm before removing).
        use_gpu: move model/criterion to CUDA and wrap in DataParallel.
        seed: CUDA RNG seed (only applied when use_gpu is True).

    Returns:
        The trained model (DataParallel-wrapped when use_gpu).
    """
    args = parser.parse_args()  # CLI flags: num_bits, num_class, dataset, checkpoint, anchor_num, batch_size
    pprint(args)
    num_bits = args.num_bits
    model = CNN(model_name='alexnet', bit=num_bits, class_num=args.num_class)
    criterion = custom_loss(num_bits=num_bits)
    arch = 'cnn_'
    filename = arch + args.dataset + '_' + str(num_bits) + "bits"
    checkpoint_filename = os.path.join(args.checkpoint, filename + '.pt')
    if use_gpu:
        model = model.cuda()
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        criterion = criterion.cuda()
        torch.cuda.manual_seed(seed)
    running_loss = 0.0  # NOTE(review): accumulated but never reset or reported
    start_epoch = 0
    batch_time = AverageMeter()  # NOTE(review): meters are printed but never .update()d
    data_time = AverageMeter()
    end = time.time()
    best_prec = -99999  # best MAP so far; sentinel lower than any real MAP
    k = 10500  # NOTE(review): unused here — confirm before deleting
    n_samples = 200000  # assumed upper bound on training-set size — TODO confirm vs dataset
    alpha = 0.4    # EMA decay for per-sample outputs (updated once per epoch)
    alpha_1 = 0.99  # EMA decay for anchor outputs (updated every iteration)
    mask_labels = torch.from_numpy(mask_labels).long().cuda()
    # Temporal-ensembling buffers (Z = EMA accumulator, z = bias-corrected
    # target, h = current-epoch raw outputs); one trio per sample, one per anchor.
    Z_h1 = torch.zeros(n_samples, num_bits).float().cuda()  # intermediate values
    z_h1 = torch.zeros(n_samples, num_bits).float().cuda()  # temporal outputs
    h1 = torch.zeros(n_samples, num_bits).float().cuda()    # current outputs
    Z_h2 = torch.zeros(args.anchor_num, num_bits).float().cuda()  # intermediate values
    z_h2 = torch.zeros(args.anchor_num, num_bits).float().cuda()  # temporal outputs
    h2 = torch.zeros(args.anchor_num, num_bits).float().cuda()    # current outputs
    for epoch in range(start_epoch, total_epochs):
        model.train(True)
        # Ramp the LR and Adam betas over training; this shadows the
        # learning_rate parameter on purpose (schedule wins).
        rampup_value = rampup(epoch)
        rampdown_value = rampdown(epoch)
        learning_rate = rampup_value * rampdown_value * 0.00005
        adam_beta1 = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
        adam_beta2 = step_rampup(epoch) * 0.99 + (1 - step_rampup(epoch)) * 0.999
        if epoch == 0:
            u_w = 0.0  # no unsupervised weight on the first epoch
        else:
            u_w = rampup_value
        u_w_m = u_w * 5
        u_w_m = torch.autograd.Variable(torch.FloatTensor([u_w_m]).cuda(), requires_grad=False)
        # A fresh optimizer each epoch so the ramped lr/betas take effect
        # (per-epoch Adam state is discarded as a side effect).
        optimizer = Adam(model.parameters(), lr=learning_rate,
                         betas=(adam_beta1, adam_beta2), eps=1e-8, amsgrad=True)
        anchors_data, anchor_Label = generate_anchor_vectors(dict_loader)
        for iteration, data in enumerate(dataloader, 0):
            # Sample 100 random anchors for this iteration.
            anchor_index = np.arange(args.anchor_num)
            np.random.shuffle(anchor_index)
            anchor_index = anchor_index[:100]
            anchor_index = torch.from_numpy(anchor_index).long().cuda()
            anchor_inputs = anchors_data[anchor_index, :, :, :]
            anchor_labels = anchor_Label[anchor_index, :]
            inputs, labels, index = data['image'], data['labels'], data['index']
            labels = labels.float()
            mask_flag = Variable(mask_labels[index], requires_grad=False)
            idx = (mask_flag > 0)  # boolean mask selecting the labeled rows
            # Skip ragged tail batches so similarity matrices keep a fixed size.
            if index.shape[0] == args.batch_size:
                anchor_batch_S, anchor_batch_W = CalcSim(labels[idx, :].cuda(), anchor_labels.cuda())
                if inputs.size(3) == 3:
                    # assumes NHWC input when the last dim is 3 — convert to NCHW
                    inputs = inputs.permute(0, 3, 1, 2)
                inputs = inputs.type(torch.FloatTensor)
                # Ensemble targets for this batch's samples and anchors.
                zcomp_h1 = z_h1[index.cuda(), :]
                zcomp_h2 = z_h2[anchor_index, :]
                labeled_batch_S, labeled_batch_W = CalcSim(labels[idx, :].cuda(), labels[idx, :].cuda())
                if use_gpu:
                    inputs = Variable(inputs.cuda(), requires_grad=False)
                    anchor_batch_S = Variable(anchor_batch_S.cuda(), requires_grad=False)
                    anchor_batch_W = Variable(anchor_batch_W.cuda(), requires_grad=False)
                    labeled_batch_S = Variable(labeled_batch_S.cuda(), requires_grad=False)
                    labeled_batch_W = Variable(labeled_batch_W.cuda(), requires_grad=False)
                # zero the parameter gradients
                optimizer.zero_grad()
                y_h1 = model(inputs)
                y_h2 = model(anchor_inputs)
                # Pairwise sample-vs-anchor similarity, squashed to (0,1);
                # 48/num_bits*0.4 rescales the code-length-dependent dot product.
                y = F.sigmoid(48 / num_bits * 0.4 * torch.matmul(y_h1, y_h2.permute(1, 0)))
                loss, l_batch_loss, m_loss = criterion(
                    y, y_h1, y_h2, anchor_batch_S, anchor_batch_W,
                    labeled_batch_S, labeled_batch_W, zcomp_h1, zcomp_h2,
                    mask_flag, u_w_m, epoch, num_bits)
                # Stash current outputs for the EMA updates below.
                h1[index, :] = y_h1.data.clone()
                h2[anchor_index, :] = y_h2.data.clone()
                # backward+optimize
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                # Anchor EMA updated every iteration with startup bias correction
                # (1 - alpha_1**t), as in Adam / temporal ensembling.
                Z_h2 = alpha_1 * Z_h2 + (1. - alpha_1) * h2
                z_h2 = Z_h2 * (1. / (1. - alpha_1**(epoch + 1)))
                print(
                    "Epoch[{}]({}/{}): Time:(data {:.3f}/ batch {:.3f}) Loss_H: {:.4f}/{:.4f}/{:.4f}"
                    .format(epoch, iteration, len(dataloader), data_time.val,
                            batch_time.val, loss.item(), l_batch_loss.item(), m_loss.item()))
        # Per-sample EMA refreshed once per epoch from the epoch's outputs.
        # NOTE(review): nesting reconstructed from flattened source — the
        # epoch-level placement of these two lines should be confirmed.
        Z_h1 = alpha * Z_h1 + (1. - alpha) * h1
        z_h1 = Z_h1 * (1. / (1. - alpha**(epoch + 1)))
        if epoch % 1 == 0:  # validate every epoch (modulus kept for easy tuning)
            MAP = helpers.validate(model, dataloader_test, test_loader)
            print("Test image map is:{}".format(MAP))
            is_best = MAP > best_prec
            best_prec = max(best_prec, MAP)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                prefix=arch,
                num_bits=num_bits,
                filename=checkpoint_filename)
    return model
class Mind:
    """Q-learning agent wrapping a CNN function approximator.

    Collects (state, action, reward, next-state) feedback, periodically
    fits the network on the replayed batch, and acts epsilon-greedily with
    an exploration rate that decays with lifetime.
    """

    def __init__(self, n_actions, input_shape, save_path=None, load_path=None,
                 action_inverses=None, update_interval=256 * 2, save_interval=5):
        """Build the agent.

        Args:
            n_actions: size of the discrete action space.
            input_shape: shape of a single state, passed to the CNN.
            save_path: where to persist the network (None = never save).
            load_path: optional weights file to restore on startup.
            action_inverses: optional map action -> its inverse action
                (used by the commented-out anti-dithering logic in action()).
            update_interval: train the network every this many actions.
            save_interval: save every this many network updates.
        """
        self.n_actions = n_actions
        self.save_path = save_path
        self.network = CNN(n_out=n_actions, input_shape=input_shape)
        if load_path is not None:  # FIX: identity comparison for None
            self.network.load(load_path)
        self.n_features = input_shape
        self.data = []                    # replay buffer of feedback dicts
        self.current_episode_count = 1    # actions taken this episode
        self.random_actions = 0           # exploratory actions this episode
        self.last_action = None
        self.last_action_random = False
        # FIX: avoid mutable default argument; {} was shared across instances.
        self.action_inverses = {} if action_inverses is None else action_inverses
        self.lifetime = 1                 # total actions over the agent's life
        self.update_interval = update_interval
        self.save_interval = save_interval
        self.n_updates = 1

    def q(self, state):
        """Return the Q-value vector predicted by the network for one state."""
        return self.network.predict(np.expand_dims(np.array(state), axis=0))[0]

    def should_explore(self, state):
        """Decide to explore with probability 1000/(1000+lifetime) (decays over time)."""
        return np.random.random() < 1000 / (1000 + self.lifetime)

    def explore_action(self, state):
        """Pick a uniformly random action."""
        return np.random.randint(0, self.n_actions)

    def action(self, state):
        """Choose an action epsilon-greedily; also drives periodic update/save."""
        q = self.q(state)
        # if self.last_action_random:
        #     if self.last_action in self.action_inverses:
        #         q[self.action_inverses[self.last_action]] = float('-inf')
        action = np.argmax(q)
        if self.should_explore(state):
            self.random_actions += 1
            action = self.explore_action(state)
            self.last_action_random = True
        else:
            self.last_action_random = False
        if self.lifetime % self.update_interval == 0:
            self.update(alpha=0.9)
            self.n_updates += 1
            if self.n_updates % self.save_interval == 0:
                if self.save_path is not None:  # FIX: identity comparison for None
                    self.save(self.save_path)
                    print('saved')
        self.last_action = action
        self.current_episode_count += 1
        self.lifetime += 1
        return action

    def save(self, path):
        """Persist the network weights to *path*."""
        self.network.save(path)

    def reset(self):
        """End the episode: report exploration stats and reset per-episode counters."""
        # FIX: the original assigned self.count, an attribute used nowhere
        # else (dead store); the per-episode counter is current_episode_count.
        self.current_episode_count = 1
        print('Random actions: ', self.random_actions)
        self.random_actions = 0

    def q_target(self, reward, best_next, alpha):
        """One-step Q-learning target: reward + alpha * max_a' Q(s', a')."""
        return reward + alpha * best_next

    def feedback(self, old_action, old_state, reward, new_state):
        """Record one transition in the replay buffer for the next update()."""
        self.data.append({
            'Q_max': np.max(self.q(new_state)),
            'reward': reward,
            'old_state': old_state,
            'old_action': old_action
        })

    def update(self, alpha=0.6):
        """Fit the network on (and then clear) the buffered transitions.

        For each sample, only the taken action's Q-value is replaced with the
        bootstrapped target; the other outputs keep the network's predictions.
        """
        np.random.shuffle(self.data)
        samples = self.data
        self.data = []
        states = []
        ys = []
        for sample in samples:
            y = self.q(sample['old_state'])
            y[sample['old_action']] = self.q_target(sample['reward'],
                                                    best_next=sample['Q_max'],
                                                    alpha=alpha)
            # y[sample['old_action']] = sarsa_target(sample['reward'], next_action=sample['Q_max'], alpha=alpha)
            states.append(sample['old_state'])
            ys.append(y)
        self.network.train(np.array(states), np.array(ys))