#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "base64.h"

/* Number of base64 characters that make up one encoded block. */
#define BLOCK_ENC_SIZE 4
/* Number of raw bytes carried by one encoded block. */
#define BLOCK_DEC_SIZE 3

/*
 * Base64 alphabet: enc_table[v] is the character for the 6-bit value v
 * (0..63). Being initialized from a string literal, the array also holds a
 * terminating NUL, so sizeof(enc_table) is 65, not 64.
 */
static const char enc_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Encode one block of up to BLOCK_DEC_SIZE bytes into exactly BLOCK_ENC_SIZE
 * base64 characters.
 *
 * All three bytes of src are read regardless of size, so the caller must
 * zero-fill the unused tail. size is the number of meaningful bytes in src;
 * output positions that have no corresponding input byte are written as '='.
 */
static void encode_block(char dst[static BLOCK_ENC_SIZE], const uint8_t src[static BLOCK_DEC_SIZE], size_t size)
{
	/* Pack the three bytes into one 24-bit group, most significant first. */
	uint32_t group = ((uint32_t)src[0] << 16) | ((uint32_t)src[1] << 8) | (uint32_t)src[2];

	/* Slice the group into four 6-bit values, again most significant first. */
	uint8_t first = (uint8_t)((group >> 18) & 0x3F);
	uint8_t second = (uint8_t)((group >> 12) & 0x3F);
	uint8_t third = (uint8_t)((group >> 6) & 0x3F);
	uint8_t fourth = (uint8_t)(group & 0x3F);

	dst[0] = enc_table[first];
	dst[1] = enc_table[second];

	if (size > 1) {
		dst[2] = enc_table[third];
	} else {
		dst[2] = '=';
	}

	if (size > 2) {
		dst[3] = enc_table[fourth];
	} else {
		dst[3] = '=';
	}
}

/**
 * Encode a byte buffer as a NUL-terminated, '='-padded base64 string.
 *
 * @param data  Bytes to encode; only dereferenced when size > 0.
 * @param size  Number of bytes in data.
 * @return A newly malloc()ed string that the caller must free(), or NULL if
 *         the allocation fails or the output length cannot be represented
 *         in a size_t.
 */
char *base64_encode(const uint8_t *data, size_t size)
{
	/* Round up: a trailing partial block still produces a full output block. */
	size_t num_blocks = size / BLOCK_DEC_SIZE;
	if (size % BLOCK_DEC_SIZE != 0) {
		num_blocks++;
	}

	/*
	 * The output is larger than the input, so the length computation below
	 * can wrap for very large sizes. Refuse instead of under-allocating and
	 * then writing past the end of the buffer.
	 */
	if (num_blocks > (SIZE_MAX - 1) / BLOCK_ENC_SIZE)
		return NULL;

	/* +1 for the terminating NUL. */
	char *out = malloc(BLOCK_ENC_SIZE * num_blocks + 1);
	if (!out)
		return NULL;

	size_t i = 0, j = 0;
	while (i < size) {
		/* Zero-filled so a short final block is padded with zero bits. */
		uint8_t block[BLOCK_DEC_SIZE] = {0};
		size_t block_size = 0;
		while (i < size && block_size < BLOCK_DEC_SIZE) {
			block[block_size++] = data[i++];
		}

		encode_block(&out[j], block, block_size);
		j += BLOCK_ENC_SIZE;
	}

	out[j] = '\0';
	return out;
}

/*
 * Look up the 6-bit value of a base64 character.
 *
 * dec_table maps a character code to its value and must have at least 128
 * entries; entries for characters outside the alphabet are expected to be 0.
 * Since 0 is also the legitimate value of enc_table[0], a zero entry is only
 * accepted for that character.
 *
 * Returns the value (as a non-negative int), or -1 if ch is not part of the
 * alphabet.
 */
static int dec_table_get(const uint8_t *dec_table, char ch)
{
	/*
	 * Go through unsigned char: where plain char is signed, a byte >= 0x80
	 * is negative and would otherwise turn into a huge out-of-bounds index.
	 * The whole alphabet is ASCII, so anything above 0x7F is rejected before
	 * the table is touched.
	 */
	unsigned char idx = (unsigned char)ch;
	if (idx > 0x7F)
		return -1;

	uint8_t b = dec_table[idx];
	if (b == 0 && ch != enc_table[0])
		return -1;
	return b;
}

static bool decode_block(uint8_t dst[static BLOCK_DEC_SIZE], const char src[static BLOCK_ENC_SIZE],
		size_t size, const uint8_t *dec_table)
{
	int a = dec_table_get(dec_table, src[0]);
	int b = dec_table_get(dec_table, src[1]);
	int c = size > 2 ? dec_table_get(dec_table, src[2]) : 0;
	int d = size > 3 ? dec_table_get(dec_table, src[3]) : 0;
	if (a < 0 || b < 0 || c < 0 || d < 0)
		return false;

	dst[0] = (a << 2) | (b >> 4);
	dst[1] = (b << 4) | (c >> 2);
	dst[2] = (c << 6) | d;
	return true;
}

/**
 * Decode a NUL-terminated, '='-padded base64 string.
 *
 * @param in        Base64 text; must be non-NULL and its length must be a
 *                  multiple of BLOCK_ENC_SIZE.
 * @param data_ptr  On success receives a malloc()ed buffer holding the
 *                  decoded bytes (the caller must free() it), or NULL when
 *                  in is the empty string.
 * @param size_ptr  On success receives the number of decoded bytes.
 * @return true on success. false if the length is not a multiple of
 *         BLOCK_ENC_SIZE, the input ends in more than two '=' characters, a
 *         character outside the alphabet is encountered, or the allocation
 *         fails; in those cases the output parameters are left untouched.
 */
bool base64_decode(const char *in, uint8_t **data_ptr, size_t *size_ptr)
{
	size_t in_len = strlen(in);
	if (in_len == 0) {
		*data_ptr = NULL;
		*size_ptr = 0;
		return true;
	}

	if (in_len % BLOCK_ENC_SIZE != 0)
		return false;
	size_t num_blocks = in_len / BLOCK_ENC_SIZE;

	/*
	 * Count trailing '=' characters. The scan stops as soon as the count is
	 * already invalid, so an input made only of '=' (e.g. "====") can no
	 * longer walk past the start of the string.
	 */
	size_t padding_len = 0;
	while (padding_len < BLOCK_DEC_SIZE && in[in_len - padding_len - 1] == '=')
		padding_len++;
	if (padding_len >= BLOCK_DEC_SIZE)
		return false;

	/*
	 * Reverse lookup table with one slot per possible byte value. Only the
	 * alphabet itself is entered: the NUL that terminates enc_table is
	 * skipped so that slot 0 stays 0, i.e. "not in the alphabet".
	 */
	uint8_t dec_table[256] = {0};
	for (uint8_t i = 0; i < sizeof(enc_table) / sizeof(enc_table[0]) - 1; i++)
		dec_table[(unsigned char)enc_table[i]] = i;

	size_t size = num_blocks * BLOCK_DEC_SIZE - padding_len;
	uint8_t *data = malloc(size);
	if (data == NULL)
		return false;

	size_t i = 0, j = 0;
	while (i < in_len) {
		/* Padding is only honoured in the final block. */
		size_t block_padding_len = 0;
		if (i + BLOCK_ENC_SIZE == in_len)
			block_padding_len = padding_len;

		uint8_t block[BLOCK_DEC_SIZE] = {0};
		if (!decode_block(block, &in[i], BLOCK_ENC_SIZE - block_padding_len, dec_table)) {
			free(data);
			return false;
		}

		/* Each padding character removes one byte from the last block. */
		memcpy(&data[j], block, BLOCK_DEC_SIZE - block_padding_len);
		i += BLOCK_ENC_SIZE;
		j += BLOCK_DEC_SIZE;
	}

	*data_ptr = data;
	*size_ptr = size;
	return true;
}
