def train_inference(imitate_net, path, max_epochs=None, self_training=False, ab=None):
    """Fine-tune the imitation (inference) network on wake-sleep samples stored
    under `path`, validating with Chamfer distance (CD) against CAD shapes.
    Returns the number of epochs actually run (early stopping on CD)."""
    if max_epochs is None:
        epochs = 1000
    else:
        epochs = max_epochs

    config = read_config.Config("config_synthetic.yml")

    if ab is not None:
        train_size = inference_train_size * ab
    else:
        train_size = inference_train_size

    # Note: the generator is always run in self-training mode here, regardless of
    # the `self_training` argument.
    generator = WakeSleepGen(f"{path}/",
                             batch_size=config.batch_size,
                             train_size=train_size,
                             canvas_shape=config.canvas_shape,
                             max_len=max_len,
                             self_training=True)
    train_gen = generator.get_train_data()

    cad_generator = Generator()
    val_gen = cad_generator.val_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    # Freeze the encoder; only the decoder is fine-tuned.
    for parameter in imitate_net.encoder.parameters():
        parameter.requires_grad = False

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")

    best_test_cd = 1e20
    patience = 20
    num_worse = 0

    for epoch in range(epochs):
        start = time.time()
        train_loss = 0
        imitate_net.train()
        for batch_idx in range(train_size // (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = 0
            # acc = 0
            for _ in range(config.num_traj):
                data, labels = next(train_gen)
                # data = data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(labels, len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                data = data.to(device)
                labels = labels.to(device)
                outputs = imitate_net([data, one_hot_labels, max_len])
                # acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                #        / (labels.shape[0] * labels.shape[1]) / config.num_traj
                loss_k = ((losses_joint(outputs, labels, time_steps=max_len + 1) /
                           (max_len + 1)) / config.num_traj)
                loss_k.backward()
                loss += float(loss_k)
                del loss_k
            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss}")
            # print(f"acc: {acc}")

        mean_train_loss = train_loss / (train_size // (config.batch_size))
        print(f"epoch {epoch} mean train loss: {mean_train_loss}")

        imitate_net.eval()
        loss = 0
        # acc = 0
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        # IOU = 0
        # COS = 0
        CD = 0
        # correct_programs = 0
        # pred_programs = 0
        for batch_idx in range(inference_test_size // config.batch_size):
            parser = ParseModelOutput(generator.unique_draw,
                                      max_len // 2 + 1,
                                      max_len,
                                      config.canvas_shape)
            with torch.no_grad():
                labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
                data_ = next(val_gen)
                one_hot_labels = prepare_input_op(labels, len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).cuda()
                data = torch.from_numpy(data_).cuda()
                # outputs = imitate_net([data, one_hot_labels, max_len])
                # loss_k = (losses_joint(outputs, labels, time_steps=max_len + 1) /
                #           (max_len + 1))
                # loss += float(loss_k)
                test_outputs = imitate_net.test(
                    [data[-1, :, 0, :, :], one_hot_labels, max_len])
                # acc += float((torch.argmax(torch.stack(test_outputs), dim=2).permute(1, 0) == labels[:, :-1]).float().sum()) \
                #        / (len(labels) * (max_len + 1)) / (inference_test_size // config.batch_size)
                pred_images, correct_prog, pred_prog = parser.get_final_canvas(
                    test_outputs, if_just_expressions=False, if_pred_images=True)
                # correct_programs += len(correct_prog)
                # pred_programs += len(pred_prog)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                # iou = np.sum(np.logical_and(target_images, pred_images), (1, 2)) / \
                #       np.sum(np.logical_or(target_images, pred_images), (1, 2))
                # cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                # IOU += np.sum(iou)
                # COS += np.sum(cos)

        # metrics["iou"] = IOU / inference_test_size
        # metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_losses = loss
        test_loss = test_losses / (inference_test_size // (config.batch_size))

        if metrics["cd"] >= best_test_cd:
            num_worse += 1
        else:
            num_worse = 0
            best_test_cd = metrics["cd"]
            torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(torch.load(f"{path}/best_dict.pth"))
            return epoch + 1

        # reduce_plat.reduce_on_plateu(metrics["cd"])
        print(
            f"Epoch {epoch}/{epochs} => train_loss: {mean_train_loss}, iou: {0}, "
            f"cd: {metrics['cd']}, test_mse: {test_loss}, test_acc: {0}"
        )
        # print(f"CORRECT PROGRAMS: {correct_programs}")
        # print(f"PREDICTED PROGRAMS: {pred_programs}")
        # print(f"RATIO: {correct_programs/pred_programs}")

        end = time.time()
        print(f"Inference train time {end - start}")

        del test_losses, outputs, test_outputs

    return epochs
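# --- Usage sketch (illustrative only, not part of the training scripts) ---
# Assumes an imitation network and a directory of wake-sleep samples already
# exist; `build_imitate_net` and the sample path below are placeholders.
#
#   imitate_net = build_imitate_net()          # hypothetical constructor
#   imitate_net.to(device)
#   epochs_run = train_inference(imitate_net,
#                                path="wake_sleep_data/generator/0",
#                                max_epochs=5)
#   print(f"fine-tuning stopped after {epochs_run} epochs")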
                    print('fetch data cost ' + str(time.time() - tick) + 'sec')
                    tick = time.time()
                    data = data[:, :, 0:config.top_k + 1, :, :, :]
                    one_hot_labels = prepare_input_op(labels, len(generator.unique_draw))
                    one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
                    data = Variable(torch.from_numpy(data)).cuda()
                    labels = Variable(torch.from_numpy(labels)).cuda()
                    data = data.permute(1, 0, 2, 3, 4, 5)

                    # forward pass
                    outputs = imitate_net([data, one_hot_labels, k])
                    loss = losses_joint(outputs, labels, time_steps=k + 1) / types_prog / \
                        num_accums
                    loss.backward()
                    loss_sum += loss.data
                    print('train one batch cost ' + str(time.time() - tick) + 'sec')

            # Clip the gradient to a fixed value to stabilize training.
            torch.nn.utils.clip_grad_norm_(imitate_net.parameters(), 20)
            optimizer.step()
            l = loss_sum
            train_loss += l
            log_value(
                'train_loss_batch',
                l.cpu().numpy(),
                epoch * (config.train_size // (config.batch_size * num_accums)) +
            loss = Variable(torch.zeros(1)).cuda().data
            acc = 0
            for _ in range(config.num_traj):
                for k in dataset_sizes.keys():
                    data, labels = next(train_gen_objs[k])
                    data = data[:, :, 0:1, :, :]
                    one_hot_labels = prepare_input_op(labels, len(generator.unique_draw))
                    one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
                    data = Variable(torch.from_numpy(data)).cuda()
                    labels = Variable(torch.from_numpy(labels)).cuda()
                    outputs = imitate_net([data, one_hot_labels, k])
                    # acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                    #        / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                    loss_k = (losses_joint(outputs, labels, time_steps=k + 1) /
                              (k + 1)) / len(dataset_sizes.keys()) / config.num_traj
                    loss_k.backward()
                    loss += loss_k.data
                    del loss_k
            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")
            print(f"acc: {acc}")

        mean_train_loss = train_loss / (config.train_size // (config.batch_size))
        print(f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")

        imitate_net.eval()
        loss = Variable(torch.zeros(1)).cuda()
                    print('fetch data cost ' + str(time.time() - tick) + 'sec')
                    tick = time.time()
                    data = data[:, :, 0:config.top_k + 1, :, :]
                    one_hot_labels = prepare_input_op(labels, len(generator.unique_draw))
                    one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
                    data = Variable(torch.from_numpy(data)).cuda()
                    labels = Variable(torch.from_numpy(labels)).cuda()
                    data = data.permute(1, 0, 2, 3, 4)

                    # forward pass
                    outputs = imitate_net([data, one_hot_labels, k])
                    loss = losses_joint(outputs, labels, time_steps=k + 1) / types_prog / \
                        num_accums
                    loss.backward()
                    loss_sum += loss.data
                    print('train one batch cost ' + str(time.time() - tick) + 'sec')

            # Clip the gradient to a fixed value to stabilize training.
            torch.nn.utils.clip_grad_norm_(imitate_net.parameters(), 20)
            optimizer.step()
            l = loss_sum
            train_loss += l
            log_value(
                'train_loss_batch',
                l.cpu().numpy(),
                epoch * (config.train_size // (config.batch_size * num_accums)) +
def train_inference(inference_net, iter):
    """Train the imitation network of `inference_net` on the wake-sleep labels of
    iteration `iter`, early-stopping on the held-out test loss."""
    config = read_config.Config("config_synthetic.yml")

    generator = WakeSleepGen(
        f"wake_sleep_data/inference/{iter}/labels/labels.pt",
        f"wake_sleep_data/inference/{iter}/labels/val/labels.pt",
        batch_size=config.batch_size,
        train_size=inference_train_size,
        test_size=inference_test_size,
        canvas_shape=config.canvas_shape,
        max_len=max_len)

    train_gen = generator.get_train_data()
    test_gen = generator.get_test_data()

    encoder_net, imitate_net = inference_net

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    best_imitate_dict = imitate_net.state_dict()
    prev_test_cd = 1e20
    prev_test_iou = 0
    patience = 5
    num_worse = 0

    for epoch in range(50):
        train_loss = 0
        Accuracies = []
        imitate_net.train()
        for batch_idx in range(inference_train_size //
                               (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = Variable(torch.zeros(1)).to(device).data
            for _ in range(config.num_traj):
                batch_data, batch_labels = next(train_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                batch_data = batch_data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).to(device)
                outputs = imitate_net([batch_data, one_hot_labels, max_len])
                loss_k = (losses_joint(outputs, batch_labels, time_steps=max_len + 1) /
                          (max_len + 1)) / config.num_traj
                loss_k.backward()
                loss += loss_k.data
                del loss_k
            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")

        mean_train_loss = train_loss / (inference_train_size // (config.batch_size))
        print(f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")

        imitate_net.eval()
        loss = Variable(torch.zeros(1)).to(device)
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        IOU = 0
        COS = 0
        CD = 0
        for batch_idx in range(inference_test_size // config.batch_size):
            with torch.no_grad():
                batch_data, batch_labels = next(test_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).to(device)
                test_outputs = imitate_net([batch_data, one_hot_labels, max_len])
                loss += (losses_joint(test_outputs, batch_labels, time_steps=max_len + 1) /
                         (max_len + 1))
                test_output = imitate_net.test([batch_data, one_hot_labels, max_len])
                pred_images, correct_prog, pred_prog = generator.parser.get_final_canvas(
                    test_output, if_just_expressions=False, if_pred_images=True)
                target_images = batch_data.cpu().numpy()[-1, :, 0, :, :].astype(dtype=bool)
                iou = np.sum(np.logical_and(target_images, pred_images), (1, 2)) / \
                      np.sum(np.logical_or(target_images, pred_images), (1, 2))
                cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                IOU += np.sum(iou)
                COS += np.sum(cos)

        metrics["iou"] = IOU / inference_test_size
        metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_losses = loss.data
        test_loss = test_losses.cpu().numpy() / (inference_test_size // (config.batch_size))

        if test_loss >= best_test_loss:
            num_worse += 1
        else:
            num_worse = 0
            best_test_loss = test_loss
            best_imitate_dict = imitate_net.state_dict()
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(best_imitate_dict)
            break

        reduce_plat.reduce_on_plateu(metrics["cd"])
        print("Epoch {}/{} => train_loss: {}, iou: {}, cd: {}, test_mse: {}".format(
            epoch,
            config.epochs,
            mean_train_loss.cpu().numpy(),
            metrics["iou"],
            metrics["cd"],
            test_loss,
        ))
        print(f"CORRECT PROGRAMS: {len(generator.correct_programs)}")

        del test_losses, test_outputs
                    perturbs = torch.from_numpy(perturbs).to(device)
                    perturb_out = perturb_out.permute(1, 0, 2)
                    # mask off ops and stop token
                    perturb_loss = F.mse_loss(
                        perturbs[labels < 396],
                        perturb_out[labels < 396]) / len(dataset_sizes.keys()) / config.num_traj
                    # perturb_loss = F.mse_loss(perturbs, perturb_out) / len(dataset_sizes.keys()) / config.num_traj
                    if not imitate_net.tf:
                        acc += float((torch.argmax(torch.stack(outputs), dim=2).permute(1, 0) == labels).float().sum()) \
                            / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                    else:
                        acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                            / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                    loss_k_token = ((losses_joint(outputs, labels, time_steps=k + 1) / (k + 1)) /
                                    len(dataset_sizes.keys()) / config.num_traj)
                    # loss_k = loss_k_token + perturb_loss
                    loss_k = loss_k_token
                    loss_k.backward()
                    loss += loss_k.data
                    loss_p += perturb_loss.data
                    loss_t += loss_k_token.data
                    del loss_k
            optimizer.step()
            train_loss += loss
            print(
                f"batch {batch_idx} train loss: {loss.cpu().numpy()}, "
                f"token loss: {loss_t.cpu().numpy()}, perturb loss: {loss_p.cpu().numpy()}"
            )
            print(f"acc: {acc}")
def train_model(csgnet, train_dataset, val_dataset, max_epochs=None):
    """Supervised training of CSGNet on (label, voxel) batches, with gradient
    accumulation, early stopping on validation loss, and IoU reporting."""
    if max_epochs is None:
        epochs = 100
    else:
        epochs = max_epochs

    optimizer = optim.Adam(
        [para for para in csgnet.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               lr_decay_epoch=3,
                               patience=config.patience)

    best_state_dict = None
    patience = 3
    prev_test_loss = 1e20
    prev_test_reward = 0
    num_worse = 0

    for epoch in range(epochs):
        train_loss = 0
        Accuracies = []
        csgnet.train()
        # Number of times to accumulate gradients
        num_accums = config.num_traj
        batch_idx = 0
        count = 0
        for batch in train_dataset:
            labels = np.stack([x[0] for x in batch])
            data = np.stack([x[1] for x in batch])
            if not len(labels) == config.batch_size:
                continue
            optimizer.zero_grad()
            loss_sum = Variable(torch.zeros(1)).cuda().data
            one_hot_labels = prepare_input_op(labels, len(unique_draws))
            one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
            data = Variable(torch.from_numpy(data)).cuda().unsqueeze(-1).float()
            labels = Variable(torch.from_numpy(labels)).cuda()

            # forward pass
            outputs = csgnet.forward2([data, one_hot_labels, max_len])
            loss = losses_joint(outputs, labels, time_steps=max_len + 1) / num_accums
            loss.backward()
            loss_sum += loss.data
            batch_idx += 1
            count += len(data)
            if batch_idx % num_accums == 0:
                # Clip the gradient to a fixed value to stabilize training.
                torch.nn.utils.clip_grad_norm_(csgnet.parameters(), 20)
                optimizer.step()
            l = loss_sum
            train_loss += l
            # print(f'train loss batch {batch_idx}: {l}')

        mean_train_loss = (train_loss * num_accums) / (count // config.batch_size)
        print(f'train loss epoch {epoch}: {float(mean_train_loss)}')
        del data, loss, loss_sum, train_loss, outputs

        test_losses = 0
        acc = 0
        csgnet.eval()
        test_reward = 0
        batch_idx = 0
        count = 0
        for batch in val_dataset:
            labels = np.stack([x[0] for x in batch])
            data = np.stack([x[1] for x in batch])
            if not len(labels) == config.batch_size:
                continue
            parser = ParseModelOutput(unique_draws,
                                      stack_size=(max_len + 1) // 2 + 1,
                                      steps=max_len,
                                      canvas_shape=[64, 64, 64],
                                      primitives=primitives)
            with torch.no_grad():
                one_hot_labels = prepare_input_op(labels, len(unique_draws))
                one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(torch.from_numpy(data)).cuda().unsqueeze(-1).float()
                labels = Variable(torch.from_numpy(labels)).cuda()
                test_output = csgnet.forward2([data, one_hot_labels, max_len])
                l = losses_joint(test_output, labels, time_steps=max_len + 1).data
                test_losses += l
                acc += float((torch.argmax(torch.stack(test_output), dim=2).permute(1, 0) == labels).float().sum()) \
                    / (labels.shape[0] * labels.shape[1])
                test_output = csgnet.test2(data, max_len)
                stack, _, _ = parser.get_final_canvas(test_output,
                                                      if_pred_images=True,
                                                      if_just_expressions=False)
                data_ = data.squeeze().cpu().numpy()
                R = np.sum(np.logical_and(stack, data_), (1, 2, 3)) / \
                    (np.sum(np.logical_or(stack, data_), (1, 2, 3)) + 1)
                test_reward += np.sum(R)
            batch_idx += 1
            count += len(data)

        test_reward = test_reward / count
        test_loss = test_losses / (count // config.batch_size)
        acc = acc / (count // config.batch_size)

        if test_loss < prev_test_loss:
            prev_test_loss = test_loss
            best_state_dict = csgnet.state_dict()
            num_worse = 0
        else:
            num_worse += 1
        if num_worse >= patience:
            csgnet.load_state_dict(best_state_dict)
            break

        print(f'test loss epoch {epoch}: {float(test_loss)}')
        print(f'test IOU epoch {epoch}: {test_reward}')
        print(f'test acc epoch {epoch}: {acc}')

        if config.if_schedule:
            reduce_plat.reduce_on_plateu(-test_reward)

        del test_losses, test_output

        if test_reward > prev_test_reward:
            prev_test_reward = test_reward
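# --- Usage sketch (illustrative only, assumptions noted) ---
# `train_model` expects iterables yielding batches of (label, voxel) pairs;
# `build_csgnet` and `make_batches` below are hypothetical helpers, not
# functions defined in this repository.
#
#   csgnet = build_csgnet()                                        # hypothetical constructor
#   train_batches = make_batches(train_set, config.batch_size)     # hypothetical batching helper
#   val_batches = make_batches(val_set, config.batch_size)
#   train_model(csgnet, train_batches, val_batches, max_epochs=20)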